Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions roll/pipeline/rlvr/rlvr_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,35 @@ class DatasetFilterConfig:
max_difficulty: Optional[float] = None
num_samples: int = 0


@dataclass
class VLMFilterConfig:
    """Configuration for filtering overlong prompts in VLM pipeline.

    Controls whether and how samples whose combined text + image token count
    exceeds ``max_prompt_length`` are dropped before training/evaluation.

    Similar to verl's RLHFDataset config options.
    """
    # Master switch: when False, no length-based filtering is performed.
    enable: bool = field(
        default=True,
        metadata={"help": "Whether to filter out samples exceeding max_prompt_length."}
    )
    # None means "choose a default" — the consumer uses max(1, cpu_count // 4).
    num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers for parallel filtering. Defaults to max(1, cpu_count // 4)."}
    )
    # Dataset column holding the (already formatted) prompt text.
    prompt_key: str = field(
        default="prompt",
        metadata={"help": "Key for the prompt text in the dataset."}
    )
    # Dataset column holding the image(s) for each sample.
    image_key: str = field(
        default="images",
        metadata={"help": "Key for the images in the dataset."}
    )
    # Dataset column with a truthy flag marking samples that carry valid images.
    image_flag_key: str = field(
        default="image_flag",
        metadata={"help": "Key for the image flag indicating valid images."}
    )


@dataclass
class RewardFilterConfig:
type: Literal["no_filter", "mean_filter", "std_filter"] = field(
Expand Down Expand Up @@ -88,6 +117,10 @@ class RLVRConfig(PPOConfig):
default_factory=DatasetFilterConfig,
metadata={"help": "Configuration for filtering dataset by source and difficulty"},
)
vlm_filter: VLMFilterConfig = field(
default_factory=VLMFilterConfig,
metadata={"help": "Configuration for filtering overlong prompts in VLM pipeline"},
)
num_return_sequences_in_group: int = field(
default=1,
metadata={"help": "The number of return sequences in one group, used in generation_args."}
Expand Down
192 changes: 174 additions & 18 deletions roll/pipeline/rlvr/rlvr_vlm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@


def format_prompt(prompt, processor, use_image=True, prompt_image_token=None):
question_template = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
question_template = "{Question}"
if isinstance(prompt, list):
messages = prompt
else:
Expand Down Expand Up @@ -162,12 +162,123 @@ def encode_function(
return encodings


def get_vlm_dataset(data_args, encode_function, processor, get_eval=False):
def filter_overlong_prompts(
    dataset: datasets.Dataset,
    processor: ProcessorMixin,
    max_prompt_length: int,
    prompt_key: str = "prompt",
    image_key: str = "images",
    image_flag_key: str = "image_flag",
    num_workers: Optional[int] = None,
) -> datasets.Dataset:
    """Drop samples whose text + image token count exceeds ``max_prompt_length``.

    Note: Returns a new filtered dataset; the original dataset is not modified.

    Args:
        dataset: Encoded dataset (prompts already formatted, images already processed).
        processor: Processor used to tokenize text together with images.
        max_prompt_length: Inclusive upper bound on the token count.
        prompt_key: Column holding the prompt text.
        image_key: Column holding the image(s).
        image_flag_key: Column whose truthy value marks samples with valid images.
        num_workers: Process count for ``Dataset.filter``. ``None`` selects
            ``max(1, cpu_count // 4)``.

    Returns:
        A dataset containing only samples within the token length limit.
    """
    if num_workers is None:
        # Mirror verl's default worker heuristic.
        num_workers = max(1, (os.cpu_count() or 4) // 4)

    total = len(dataset)
    if total == 0:
        logger.info("Dataset is empty, skipping filtering")
        return dataset

    def keep_sample(example) -> bool:
        """Return True iff the sample tokenizes to <= max_prompt_length tokens."""
        try:
            text = example[prompt_key]
            if not isinstance(text, str):
                # Prompt was not pre-formatted by encode_function; apply the chat template.
                text = processor.apply_chat_template(text, tokenize=False, add_generation_prompt=True)

            imgs = None
            if example.get(image_flag_key) and example.get(image_key):
                # Reuse the already-processed images; normalize a single image to a list.
                imgs = example[image_key]
                if not isinstance(imgs, list):
                    imgs = [imgs]

            # Same tokenization path as DataCollatorWithPaddingForMM.
            encoded = processor(text=text, images=imgs)
            return len(encoded["input_ids"][0]) <= max_prompt_length

        except Exception as e:
            # A sample we cannot tokenize is treated as overlong and dropped.
            logger.error(f"Error processing sample during filter, skipping: {e}")
            return False

    kept = dataset.filter(
        keep_sample,
        num_proc=num_workers,
        desc=f"Keeping prompts with length <= {max_prompt_length} tokens",
    )

    remaining = len(kept)
    dropped = total - remaining
    if dropped > 0:
        logger.info(
            f"Filtered {dropped}/{total} samples ({100 * dropped / total:.1f}%) "
            f"exceeding max_prompt_length={max_prompt_length}. Remaining: {remaining}"
        )
    else:
        logger.info(f"All {total} samples within max_prompt_length={max_prompt_length}")

    return kept


def get_vlm_dataset(
data_args,
encode_function,
processor,
get_eval=False,
max_prompt_length: Optional[int] = None,
vlm_filter: Optional["VLMFilterConfig"] = None,
):
"""Load and encode VLM dataset with optional filtering.

Args:
data_args: Data arguments containing dataset path and preprocessing settings.
encode_function: Function to encode the dataset.
processor: Processor for tokenization and image processing.
get_eval: Whether to load evaluation dataset.
max_prompt_length: Maximum prompt length for filtering. If None, filtering is disabled.
vlm_filter: VLMFilterConfig for filtering settings. If None, uses default values.
"""
# Import here to avoid circular import
from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig

if vlm_filter is None:
vlm_filter = VLMFilterConfig()

cache_path = getattr(data_args, "cache_path", None)
if cache_path:
cache_path = os.path.join(cache_path, "val" if get_eval else "train")
if cache_path and os.path.exists(cache_path):

# When filtering is enabled with max_prompt_length, skip cache to ensure correct filtering.
# The cached dataset may have been filtered with a different max_prompt_length.
# Filtering is always done after loading/encoding, and we don't cache filtered results.
should_filter = vlm_filter.enable and max_prompt_length is not None
use_cache = cache_path and os.path.exists(cache_path) and not should_filter
if use_cache:
dataset = load_from_disk(cache_path)
logger.info(f"Loaded dataset from cache: {cache_path}")
return dataset

dataset = get_dataset(data_args=data_args)
Expand Down Expand Up @@ -202,8 +313,31 @@ def get_vlm_dataset(data_args, encode_function, processor, get_eval=False):
desc="Encoding dataset",
)
print(f"Encoding: {dataset}")
if cache_path:

# Filter out samples where text + image tokens exceed max_prompt_length
# This is done AFTER encoding but BEFORE saving to cache
# Uses vlm_filter config for filtering settings
if should_filter:
dataset = filter_overlong_prompts(
dataset=dataset,
processor=processor,
max_prompt_length=max_prompt_length,
prompt_key=vlm_filter.prompt_key,
image_key=vlm_filter.image_key,
image_flag_key=vlm_filter.image_flag_key,
num_workers=vlm_filter.num_workers,
)

# Only cache if no filtering was applied
# This prevents caching filtered datasets which may have different max_prompt_length
if cache_path and not should_filter:
dataset.save_to_disk(cache_path)
logger.info(f"Saved dataset to cache: {cache_path}")
elif cache_path and should_filter:
logger.info(
f"Skipping cache save because max_prompt_length={max_prompt_length} filtering was applied. "
"Next run will re-encode and re-filter with the same max_prompt_length."
)
return dataset


Expand All @@ -222,7 +356,12 @@ def __init__(self, pipeline_config: RLVRConfig):
self.tokenizer.padding_side = "left"

dataset = get_vlm_dataset(
self.pipeline_config.actor_train.data_args, encode_function, self.processor, get_eval=False
self.pipeline_config.actor_train.data_args,
encode_function,
self.processor,
get_eval=False,
max_prompt_length=self.pipeline_config.prompt_length,
vlm_filter=self.pipeline_config.vlm_filter,
)
# update domain field, DynamicSamplingScheduler requires
dataset = dataset.map(
Expand All @@ -244,7 +383,12 @@ def __init__(self, pipeline_config: RLVRConfig):
self.val_dataset = None
if self.pipeline_config.validation and self.pipeline_config.validation.data_args:
self.val_dataset = get_vlm_dataset(
self.pipeline_config.validation.data_args, encode_function, self.processor, get_eval=True
self.pipeline_config.validation.data_args,
encode_function,
self.processor,
get_eval=True,
max_prompt_length=self.pipeline_config.prompt_length,
vlm_filter=self.pipeline_config.vlm_filter,
)
self.val_dataset = self.val_dataset.map(
partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
Expand Down Expand Up @@ -316,9 +460,11 @@ def __init__(self, pipeline_config: RLVRConfig):
else:
domain_batch_size = int(domain_ratios[domain] * self.pipeline_config.rollout_batch_size)
accumulated += domain_batch_size
generate_scheduler = ray.remote(DynamicSamplingScheduler).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(), soft=False
generate_scheduler = (
ray.remote(DynamicSamplingScheduler).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(), soft=False
)
)
).remote(
pipeline_config=self.pipeline_config,
Expand Down Expand Up @@ -353,9 +499,11 @@ def __init__(self, pipeline_config: RLVRConfig):
if self.val_dataset:
val_pipeline_config = copy.deepcopy(self.pipeline_config)
val_pipeline_config.is_use_additional_prompts = False
self.val_generate_scheduler = ray.remote(DynamicSamplingScheduler).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(), soft=False
self.val_generate_scheduler = (
ray.remote(DynamicSamplingScheduler).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(), soft=False
)
)
).remote(
pipeline_config=val_pipeline_config,
Expand Down Expand Up @@ -494,9 +642,11 @@ def run(self):
metrics_mgr.add_metric("time/val_step", val_step_timer.last)

# 要按domain group by生成对应的batch
with actor_infer_timer, actor_infer_response_timer, Timer(
name="step_generate", logger=None
) as step_generate_timer:
with (
actor_infer_timer,
actor_infer_response_timer,
Timer(name="step_generate", logger=None) as step_generate_timer,
):
domain_batches = {}
scheduler_refs = {}
for domain, scheduler in self.generate_schedulers.items():
Expand Down Expand Up @@ -527,7 +677,9 @@ def run(self):
batch.meta_info["_broadcast_non_tensor_batch"] = True
batch.meta_info["loss_mask_keys"] = ["response_mask", "final_response_mask"]

batch.non_tensor_batch['sample_uuid'] = np.array([str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object)
batch.non_tensor_batch["sample_uuid"] = np.array(
[str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object
)
with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer:
if self.pipeline_config.enable_reference:
ref_log_probs = self.reference.compute_log_probs(batch, blocking=True)
Expand All @@ -542,7 +694,9 @@ def run(self):
values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False)

if self.pipeline_config.enable_old_logprobs_recompute:
old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False)
old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(
batch, blocking=False
)
old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs)
agg_entropy = agg_loss(
loss_mat=old_log_probs.batch["entropy"],
Expand Down Expand Up @@ -723,7 +877,9 @@ def val(self, global_step):
{"global_step": self.global_step, "max_steps": self.pipeline_config.max_steps, "is_training": False}
)
generate_output: DataProto = ray.get(
self.val_generate_scheduler.get_batch.remote(data=batch, global_step=global_step, batch_size=len(self.val_dataset)),
self.val_generate_scheduler.get_batch.remote(
data=batch, global_step=global_step, batch_size=len(self.val_dataset)
),
timeout=self.pipeline_config.rpc_timeout,
)
generate_output.meta_info.pop("is_offload_states", None)
Expand Down
Loading