fix: correct CriticWorker tqdm total to match epochs=1 dataloader #377

Open
dubin555 wants to merge 1 commit into alibaba:main from dubin555:oss-scout/verify-critic-tqdm-total-mismatch

Problem

In CriticWorker.train_step (roll/pipeline/base_worker.py), the dataloader is created with epochs=1, but the tqdm progress bar total is calculated using ppo_epochs:

dataloader = data.make_iterator(
    mini_batch_size=backward_batch_size, epochs=1  # epochs=1
)

for batch_idx, mini_batch in tqdm(
    enumerate(dataloader),
    total=data.batch.batch_size[0] * ppo_epochs // backward_batch_size  # uses ppo_epochs
):

When ppo_epochs > 1, the progress bar total is inflated by a factor of ppo_epochs, so the bar never reaches 100%. For example, with ppo_epochs=2 the bar shows 50% when training is actually complete.

Fix

Remove ppo_epochs from the tqdm total calculation to match the epochs=1 dataloader:

total=data.batch.batch_size[0] // backward_batch_size
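
The arithmetic behind the mismatch can be checked with a small standalone sketch. Everything in it is an assumption made for illustration: `make_iterator` below is a hypothetical stand-in for `data.make_iterator` (it only yields index ranges and assumes the batch size is divisible by the mini-batch size), and the numbers 64, 8, and 2 are made up. It does not reproduce the actual CriticWorker code.

```python
def make_iterator(num_samples, mini_batch_size, epochs):
    """Hypothetical stand-in for data.make_iterator: yields index ranges."""
    for _ in range(epochs):
        # Only full mini-batches are produced in this simplified sketch.
        for start in range(0, num_samples - mini_batch_size + 1, mini_batch_size):
            yield range(start, start + mini_batch_size)


def progress_at_end(num_samples, backward_batch_size, ppo_epochs, include_ppo_epochs):
    """Return (steps actually run, progress-bar total, final fraction shown)."""
    # The dataloader is always built with epochs=1, as in the PR description.
    steps = sum(1 for _ in make_iterator(num_samples, backward_batch_size, epochs=1))
    multiplier = ppo_epochs if include_ppo_epochs else 1
    total = num_samples * multiplier // backward_batch_size
    return steps, total, steps / total


# Illustrative values only: 64 samples, mini-batches of 8, ppo_epochs=2.
buggy = progress_at_end(64, 8, 2, include_ppo_epochs=True)
fixed = progress_at_end(64, 8, 2, include_ppo_epochs=False)
print(buggy)  # (8, 16, 0.5)
print(fixed)  # (8, 8, 1.0)
```

With the ppo_epochs multiplier the loop finishes after 8 steps against a total of 16; without it, the step count and the total agree.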

Files Changed

  • roll/pipeline/base_worker.py — tqdm total calculation
  • tests/pipeline/test_critic_tqdm_total.py — regression test

CriticWorker.train_step creates the dataloader with epochs=1, but the
tqdm progress bar total was calculated using ppo_epochs. This caused the
progress bar to never reach 100% when ppo_epochs > 1 (e.g. showing 50%
completion with ppo_epochs=2 when training was actually done).

Remove the ppo_epochs multiplier from the tqdm total so it matches the
actual number of iterations produced by the dataloader.

CLAassistant commented Mar 14, 2026

CLA assistant check
All committers have signed the CLA.
