
Commit 67549f2

[gkd] support gkd use_logits_to_keep & padding_free & packing (#4658)
1 parent 871278c commit 67549f2

File tree

17 files changed: +191 −32 lines

README.md

Lines changed: 2 additions & 2 deletions
@@ -75,7 +75,7 @@ You can contact us and communicate with us by adding our group:
 
 ## 🎉 News
 - 🎁 2025.06.18: Support for accelerating the ms-swift [inference](https://github.com/modelscope/ms-swift/blob/main/examples/infer/sglang), deployment, evaluation, and UI modules using the [sglang](https://github.com/sgl-project/sglang) inference acceleration engine. Simply set `--infer_backend sglang` to enable it.
-- 🎁 2025.06.15: Support for GKD training on both pure text large models and multimodal models. Training scripts can be found here: [Pure Text](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh), [Multimodal](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd.sh).
+- 🎁 2025.06.15: Support for GKD training on both pure text large models and multimodal models. Training scripts can be found here: [Pure Text](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd), [Multimodal](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd).
 - 🎁 2025.06.11: Support for using Megatron parallelism techniques for RLHF training. The training script can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf).
 - 🎁 2025.05.29: Support sequence parallel in pt, sft, dpo and grpo, check script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text).
 - 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models).

@@ -288,7 +288,7 @@ Supported Training Methods:
 | GRPO Training | [](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal) |||| [](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/external) ||
 | Reward Model Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) |||
 | PPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) |||
-| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd.sh) |
+| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd) |
 | KTO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
 | CPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) |||
 | SimPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) |||

README_CN.md

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@
 
 ## 🎉 News
 - 🎁 2025.06.18: Support accelerating the ms-swift [inference](https://github.com/modelscope/ms-swift/blob/main/examples/infer/sglang), deployment, evaluation, and UI modules with the [sglang](https://github.com/sgl-project/sglang) inference acceleration engine; simply set `--infer_backend sglang`.
-- 🎁 2025.06.15: Support GKD training for both pure-text LLMs and multimodal models. Training scripts: [Pure Text](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh), [Multimodal](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd.sh).
+- 🎁 2025.06.15: Support GKD training for both pure-text LLMs and multimodal models. Training scripts: [Pure Text](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd), [Multimodal](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd).
 - 🎁 2025.06.11: Support RLHF training with Megatron parallelism; see the training script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf).
 - 🎁 2025.05.29: Support sequence parallelism for pt, sft, dpo, and grpo; see the [script](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text).
 - 🎁 2025.05.11: Reward models in GRPO support custom processing logic; see the GenRM example [here](./docs/source/Instruction/GRPO.md#自定义奖励模型).

@@ -277,7 +277,7 @@ print(f'response: {resp_list[0].choices[0].message.content}')
 | GRPO Training | [](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal) |||| [](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/external) ||
 | Reward Model Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) |||
 | PPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) |||
-| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd.sh) |
+| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd) |
 | KTO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
 | CPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) |||
 | SimPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) |||

docs/source/Instruction/命令行参数.md

Lines changed: 2 additions & 1 deletion
@@ -407,7 +407,8 @@ RLHF arguments inherit from the [training arguments](#训练参数).
 - undesirable_weight: Loss weight $\lambda_U$ for undesirable responses in the KTO algorithm; default is `1.`.
 - loss_scale: Overrides the template argument; default is 'last_round'.
 - temperature: Default is 0.9; used in PPO, GRPO, and GKD.
-- lmbda: Default is 0.5. Used in GKD. The lambda parameter that controls the student data fraction (i.e., the proportion of on-policy student-generated outputs).
+- lmbda: Default is 0.5. Used in GKD. The lambda parameter that controls the student data fraction (i.e., the proportion of on-policy student-generated outputs). If lmbda is 0, student-generated data is not used.
+- sft_alpha: Default is 0. Controls the weight of the sft_loss added in GKD; the final loss is `gkd_loss + sft_alpha * sft_loss`.
 - seq_kd: Default is False. Used in GKD. Controls whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on teacher-generated outputs).
 - Note: You can run inference on the dataset with the teacher model in advance (accelerated with inference engines such as vLLM, SGLang, or LMDeploy) and set `seq_kd` to False during training. Alternatively, set `seq_kd` to True to generate sequences with the teacher model during training (this guarantees different generated data across epochs, but is slower).
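A minimal sketch of the per-batch decision that `lmbda` controls (illustrative only, not the ms-swift implementation; `pick_training_batch` and `generate_fn` are hypothetical names):

```python
import random
from typing import Callable

def pick_training_batch(
    batch: dict,
    generate_fn: Callable[[list], list],  # hypothetical: samples responses from the student
    lmbda: float = 0.5,
) -> dict:
    # With probability lmbda, train on-policy: replace the dataset responses
    # with responses the student generates itself. Otherwise keep the fixed
    # dataset responses. lmbda=0 therefore never uses student-generated data.
    if random.random() < lmbda:
        batch = dict(batch, responses=generate_fn(batch["prompts"]))
    return batch

# lmbda=0 leaves the batch untouched:
demo = {"prompts": ["Describe the image."], "responses": ["A cat on a sofa."]}
assert pick_training_batch(demo, lambda p: ["<sampled>"], lmbda=0) == demo
```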

docs/source/Instruction/预训练与微调.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 | GRPO Training | [](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal) |||| [](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/external) ||
 | Reward Model Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) |||
 | PPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) |||
-| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd.sh) |
+| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd) |
 | KTO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
 | CPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) |||
 | SimPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) |||

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 1 deletion
@@ -417,7 +417,8 @@ RLHF arguments inherit from the [training arguments](#training-arguments).
 - undesirable_weight: Loss weight $\lambda_U$ for undesirable responses in the KTO algorithm, default is `1.`.
 - loss_scale: Override template arguments, default is 'last_round'.
 - temperature: Default is 0.9; this parameter will be used in PPO, GRPO and GKD.
-- lmbda: Default is 0.5. This parameter is used in GKD. It is the lambda parameter that controls the student data fraction (i.e., the proportion of on-policy student-generated outputs).
+- lmbda: Default is 0.5. This parameter is used in GKD. It is the lambda parameter that controls the student data fraction (i.e., the proportion of on-policy student-generated outputs). If lmbda is 0, student-generated data is not used.
+- sft_alpha: Default is 0. Controls the weight of the sft_loss added in GKD. The final loss is `gkd_loss + sft_alpha * sft_loss`.
 - seq_kd: Default is False. This parameter is used in GKD. It is the `seq_kd` parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on teacher-generated output).
 - Note: You can perform inference on the dataset using the teacher model in advance (accelerated by inference engines such as vLLM, SGLang, or LMDeploy), and set `seq_kd` to False during training. Alternatively, you can set `seq_kd` to True, which will use the teacher model to generate sequences during training (ensuring different generated data across multiple epochs, but slower).
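A rough sketch of how `temperature` and `sft_alpha` enter the objective (an illustration under assumed tensor shapes, not the ms-swift implementation; the plain forward KL below stands in for the divergence GKD actually uses, and labels are assumed already shifted for next-token prediction):

```python
import torch
import torch.nn.functional as F

def gkd_loss_with_sft(student_logits: torch.Tensor,   # (batch, seq, vocab)
                      teacher_logits: torch.Tensor,   # (batch, seq, vocab)
                      labels: torch.Tensor,           # (batch, seq), -100 = ignore
                      sft_alpha: float = 0.0,
                      temperature: float = 0.9) -> torch.Tensor:
    # Distillation term on response tokens only: per-token KL(teacher || student)
    # with both distributions softened by the shared temperature.
    mask = labels != -100
    s_logp = F.log_softmax(student_logits / temperature, dim=-1)
    t_prob = F.softmax(teacher_logits / temperature, dim=-1)
    kl_per_token = F.kl_div(s_logp, t_prob, reduction="none").sum(-1)
    gkd_loss = (kl_per_token * mask).sum() / mask.sum().clamp(min=1)

    # Optional SFT term, combined exactly as documented:
    # final loss = gkd_loss + sft_alpha * sft_loss
    sft_loss = F.cross_entropy(student_logits.flatten(0, 1),
                               labels.flatten(), ignore_index=-100)
    return gkd_loss + sft_alpha * sft_loss
```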

docs/source_en/Instruction/Pre-training-and-Fine-tuning.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ Training Capability:
 | GRPO Training | [](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal) |||| [](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/external) ||
 | Reward Model Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) |||
 | PPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo) |||
-| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd.sh) |
+| GKD Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/gkd) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd) |
 | KTO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
 | CPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) |||
 | SimPO Training || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) || [](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) |||
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+
+export teacher_model='OpenGVLab/InternVL3-8B'
+
+NPROC_PER_NODE=4 \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+swift infer \
+    --model $teacher_model \
+    --infer_backend vllm \
+    --val_dataset 'modelscope/coco_2014_caption:validation#5000' \
+    --gpu_memory_utilization 0.9 \
+    --max_model_len 8192 \
+    --max_new_tokens 2048 \
+    --write_batch_size 1000 \
+    --result_path new_coco_dataset.jsonl
+
+
+# 4 * 42GiB, 3.05s/it
+NPROC_PER_NODE=4 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+swift rlhf \
+    --rlhf_type gkd \
+    --model OpenGVLab/InternVL3-2B-Pretrained \
+    --teacher_model $teacher_model \
+    --train_type full \
+    --dataset 'new_coco_dataset.jsonl' \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --learning_rate 1e-5 \
+    --freeze_vit true \
+    --gradient_accumulation_steps 1 \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 4096 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --save_only_model true \
+    --dataloader_num_workers 4 \
+    --dataset_num_proc 4 \
+    --deepspeed zero2 \
+    --padding_free true \
+    --attn_impl flash_attn \
+    --lmbda 0
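This new example follows the offline recipe from the `seq_kd` note in the parameter docs above: the teacher's responses are exported once with `swift infer` (vLLM-accelerated) to `new_coco_dataset.jsonl`, and the student is then trained on that fixed file with `--lmbda 0`, i.e., with no on-policy student-generated data.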

examples/train/multimodal/rlhf/gkd.sh renamed to examples/train/multimodal/rlhf/gkd/full.sh

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# 4 * 45GiB
+# 4 * 45GiB, 10.29s/it
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 MASTER_PORT=29501 \
