@freeliuzc
Collaborator

Motivation

  1. The kernel treats bad_tokens as a flat 1D list, but gpu_model_runner initializes it as a 2D tensor with shape [batch_size, vocab_size].
  2. As a result, every query in the batch reads the bad_tokens row of batch 0 instead of its own row.
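
The mismatch can be illustrated with a minimal host-side C++ sketch. It is not the kernel code: the helper names and the `row_len` parameter are hypothetical, and it assumes the 2D tensor is stored row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: reads entry i for batch bi WITHOUT a row offset,
// i.e. the 2D tensor is treated as one list shared by all batches.
int64_t read_without_offset(const std::vector<int64_t>& bad_tokens,
                            int /*bi*/, int i) {
  assert(i >= 0);
  return bad_tokens[static_cast<std::size_t>(i)];
}

// Hypothetical helper: reads entry i of row bi, assuming a row-major
// layout in which each row holds row_len entries.
int64_t read_with_offset(const std::vector<int64_t>& bad_tokens,
                         int bi, int i, int row_len) {
  assert(bi >= 0 && i >= 0 && i < row_len);
  return bad_tokens[static_cast<std::size_t>(bi) * row_len + i];
}
```

With two rows `{1, 2}` and `{3, 4}` stored back to back, `read_without_offset` returns row 0's entry for batch 1, while `read_with_offset` returns batch 1's own entry.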

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings January 16, 2026 08:32
@paddle-bot

paddle-bot bot commented Jan 16, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment

Pull request overview

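This PR fixes a critical bug in the token penalty kernel. In the original code, the ban_bad_words kernel made the queries of every batch read the bad_tokens data of batch 0 instead of the bad_tokens of their own batch. The fix computes the correct per-batch bad_tokens pointer offset inside the kernel.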

Changes:

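  • Fixed the ban_bad_words kernel so that each batch accesses its own bad_tokens data
  • Unified the code formatting, changing indentation from 4 spaces to 2 spaces (matching the project's clang-format configuration)
  • Renamed the parameter bad_words_list to bad_tokens for naming consistency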
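
As a rough illustration of the first change, the per-batch offset can be emulated on the CPU. This is a sketch under assumptions, not the PR's kernel: the outer loop stands in for `blockIdx.x`, and the function name, the bounds check, and the `kBannedLogit` value are all hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed sentinel written to banned logits; the real kernel's value may differ.
constexpr float kBannedLogit = -1e10f;

// CPU stand-in for the kernel: each iteration of the outer loop plays the
// role of one CUDA block handling batch index bi.
void ban_bad_words_cpu(std::vector<float>& logits,
                       const std::vector<int64_t>& bad_tokens,
                       int bs, int vocab_size, int bad_words_len) {
  assert(logits.size() == static_cast<std::size_t>(bs) * vocab_size);
  assert(bad_tokens.size() == static_cast<std::size_t>(bs) * bad_words_len);
  for (int bi = 0; bi < bs; ++bi) {
    float* logits_now = logits.data() + static_cast<std::size_t>(bi) * vocab_size;
    // The per-batch offset: each batch reads its own row of bad tokens.
    const int64_t* bad_tokens_now =
        bad_tokens.data() + static_cast<std::size_t>(bi) * bad_words_len;
    for (int i = 0; i < bad_words_len; ++i) {
      const int64_t id = bad_tokens_now[i];
      if (id < 0 || id >= vocab_size) continue;  // assumed: skip padding or invalid ids
      logits_now[id] = kBannedLogit;
    }
  }
}
```

With the offset in place, a token banned only for batch 1 leaves the logits of batch 0 untouched, and vice versa.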

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
custom_ops/gpu_ops/token_penalty_multi_scores.cu Fixes the batch-indexing bug in the ban_bad_words kernel and unifies the code formatting
custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu Applies the same bug fix to the speculative decoding version

const int bi = blockIdx.x;
int tid = threadIdx.x;
T *logits_now = logits + bi * vocab_size;
const int64_t bad_tokens_now = bad_tokens + bi * bad_words_len;

Copilot AI Jan 16, 2026

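Type error: bad_tokens_now should be declared as const int64_t * (a pointer type), not int64_t (an integer type). This causes a compile error, because the code that follows accesses it as an array (bad_tokens_now[i]). The correct declaration is: const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_len;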

Suggested change
const int64_t bad_tokens_now = bad_tokens + bi * bad_words_len;
const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_len;
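
The type issue can be checked in isolation with a small C++ sketch. The `row_ptr` helper is hypothetical and exists only to show that adding an integer offset to a `const int64_t *` yields a pointer, which is why the result cannot initialize a plain `const int64_t`.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical helper mirroring the suggested declaration.
inline const int64_t* row_ptr(const int64_t* bad_tokens, int bi,
                              int bad_words_len) {
  // Pointer plus integer stays a pointer, so the result must be stored
  // in a `const int64_t*`, not in an integer variable.
  static_assert(
      std::is_same<decltype(bad_tokens + bi * bad_words_len),
                   const int64_t*>::value,
      "pointer arithmetic yields a pointer");
  assert(bad_tokens != nullptr && bi >= 0 && bad_words_len >= 0);
  return bad_tokens + bi * bad_words_len;
}
```

The returned pointer can then be indexed as `row_ptr(...)[i]`, matching how the kernel uses `bad_tokens_now[i]`.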

Copilot uses AI. Check for mistakes.
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length;

Copilot AI Jan 16, 2026

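The implementation in this file is correct (it uses a pointer type), but it is inconsistent with the implementation in the other file (token_penalty_multi_scores.cu line 137). Please make sure the two files stay consistent and both use the correct pointer type declaration.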

Copilot uses AI. Check for mistakes.