diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py
index 43b8ffccb..df2e91d28 100644
--- a/roll/pipeline/rlvr/rlvr_config.py
+++ b/roll/pipeline/rlvr/rlvr_config.py
@@ -17,6 +17,35 @@ class DatasetFilterConfig:
max_difficulty: Optional[float] = None
num_samples: int = 0
+
+@dataclass
+class VLMFilterConfig:
+ """Configuration for filtering overlong prompts in VLM pipeline.
+
+ Similar to verl's RLHFDataset config options.
+ """
+ enable: bool = field(
+ default=True,
+ metadata={"help": "Whether to filter out samples exceeding max_prompt_length."}
+ )
+ num_workers: Optional[int] = field(
+ default=None,
+ metadata={"help": "Number of workers for parallel filtering. Defaults to max(1, cpu_count // 4)."}
+ )
+ prompt_key: str = field(
+ default="prompt",
+ metadata={"help": "Key for the prompt text in the dataset."}
+ )
+ image_key: str = field(
+ default="images",
+ metadata={"help": "Key for the images in the dataset."}
+ )
+ image_flag_key: str = field(
+ default="image_flag",
+ metadata={"help": "Key for the image flag indicating valid images."}
+ )
+
+
@dataclass
class RewardFilterConfig:
type: Literal["no_filter", "mean_filter", "std_filter"] = field(
@@ -88,6 +117,10 @@ class RLVRConfig(PPOConfig):
default_factory=DatasetFilterConfig,
metadata={"help": "Configuration for filtering dataset by source and difficulty"},
)
+ vlm_filter: VLMFilterConfig = field(
+ default_factory=VLMFilterConfig,
+ metadata={"help": "Configuration for filtering overlong prompts in VLM pipeline"},
+ )
num_return_sequences_in_group: int = field(
default=1,
metadata={"help": "The number of return sequences in one group, used in generation_args."}
diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py
index dc7af456c..45e657558 100644
--- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py
+++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py
@@ -51,7 +51,7 @@
def format_prompt(prompt, processor, use_image=True, prompt_image_token=None):
- question_template = "{Question} Output the thinking process in and final answer (number) in tags."
+ question_template = "{Question}"
if isinstance(prompt, list):
messages = prompt
else:
@@ -162,12 +162,123 @@ def encode_function(
return encodings
-def get_vlm_dataset(data_args, encode_function, processor, get_eval=False):
+def filter_overlong_prompts(
+ dataset: datasets.Dataset,
+ processor: ProcessorMixin,
+ max_prompt_length: int,
+ prompt_key: str = "prompt",
+ image_key: str = "images",
+ image_flag_key: str = "image_flag",
+ num_workers: Optional[int] = None,
+) -> datasets.Dataset:
+ """Filter out samples where text + image tokens exceed max_prompt_length.
+
+ Note: Returns a new filtered dataset; the original dataset is not modified.
+
+ Args:
+ dataset: The dataset to filter (already encoded with formatted prompts and processed images).
+ processor: The processor to use for tokenization.
+ max_prompt_length: Maximum allowed token length (inclusive).
+ prompt_key: Key for the prompt text in the dataset.
+ image_key: Key for the images in the dataset.
+ image_flag_key: Key for the image flag indicating valid images.
+ num_workers: Number of processes for parallel filtering. Defaults to max(1, cpu_count // 4).
+
+ Returns:
+ Filtered dataset with samples within the token length limit.
+ """
+ # Default num_workers similar to verl's approach
+ if num_workers is None:
+ num_workers = max(1, (os.cpu_count() or 4) // 4)
+
+ original_len = len(dataset)
+ if original_len == 0:
+ logger.info("Dataset is empty, skipping filtering")
+ return dataset
+
+ def compute_token_count(example) -> int:
+ """Compute token count for a sample. Returns max_prompt_length + 1 on parse failure."""
+ try:
+ # Prompt is already formatted by encode_function, use directly
+ prompt = example[prompt_key]
+ if not isinstance(prompt, str):
+ # Fallback: apply chat template only if prompt is not already a string
+ prompt = processor.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
+
+ # Use already-processed images from encode_function
+ images = None
+ if example.get(image_flag_key) and example.get(image_key):
+ images = example[image_key]
+ # Handle single image or list of images
+ if not isinstance(images, list):
+ images = [images]
+
+ # Calculate token length (same path as DataCollatorWithPaddingForMM)
+ inputs = processor(text=prompt, images=images)
+ input_ids = inputs["input_ids"][0]
+
+ return len(input_ids)
+
+ except Exception as e:
+ # Return max_prompt_length + 1 on any error to filter out the sample
+ logger.error(f"Error processing sample during filter, skipping: {e}")
+ return max_prompt_length + 1
+
+ filtered_dataset = dataset.filter(
+ lambda example: compute_token_count(example) <= max_prompt_length,
+ num_proc=num_workers,
+ desc=f"Keeping prompts with length <= {max_prompt_length} tokens",
+ )
+
+ filtered_len = len(filtered_dataset)
+ filtered_count = original_len - filtered_len
+ if filtered_count > 0:
+ logger.info(
+ f"Filtered {filtered_count}/{original_len} samples ({100 * filtered_count / original_len:.1f}%) "
+ f"exceeding max_prompt_length={max_prompt_length}. Remaining: {filtered_len}"
+ )
+ else:
+ logger.info(f"All {original_len} samples within max_prompt_length={max_prompt_length}")
+
+ return filtered_dataset
+
+
+def get_vlm_dataset(
+ data_args,
+ encode_function,
+ processor,
+ get_eval=False,
+ max_prompt_length: Optional[int] = None,
+ vlm_filter: Optional["VLMFilterConfig"] = None,
+):
+ """Load and encode VLM dataset with optional filtering.
+
+ Args:
+ data_args: Data arguments containing dataset path and preprocessing settings.
+ encode_function: Function to encode the dataset.
+ processor: Processor for tokenization and image processing.
+ get_eval: Whether to load evaluation dataset.
+ max_prompt_length: Maximum prompt length for filtering. If None, filtering is disabled.
+ vlm_filter: VLMFilterConfig for filtering settings. If None, uses default values.
+ """
+ # Import here to avoid circular import
+ from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig
+
+ if vlm_filter is None:
+ vlm_filter = VLMFilterConfig()
+
cache_path = getattr(data_args, "cache_path", None)
if cache_path:
cache_path = os.path.join(cache_path, "val" if get_eval else "train")
- if cache_path and os.path.exists(cache_path):
+
+ # When filtering is enabled with max_prompt_length, skip cache to ensure correct filtering.
+    # The cached dataset is never filtered (filtered results are intentionally not cached),
+ # Filtering is always done after loading/encoding, and we don't cache filtered results.
+ should_filter = vlm_filter.enable and max_prompt_length is not None
+ use_cache = cache_path and os.path.exists(cache_path) and not should_filter
+ if use_cache:
dataset = load_from_disk(cache_path)
+ logger.info(f"Loaded dataset from cache: {cache_path}")
return dataset
dataset = get_dataset(data_args=data_args)
@@ -202,8 +313,31 @@ def get_vlm_dataset(data_args, encode_function, processor, get_eval=False):
desc="Encoding dataset",
)
print(f"Encoding: {dataset}")
- if cache_path:
+
+ # Filter out samples where text + image tokens exceed max_prompt_length
+ # This is done AFTER encoding but BEFORE saving to cache
+ # Uses vlm_filter config for filtering settings
+ if should_filter:
+ dataset = filter_overlong_prompts(
+ dataset=dataset,
+ processor=processor,
+ max_prompt_length=max_prompt_length,
+ prompt_key=vlm_filter.prompt_key,
+ image_key=vlm_filter.image_key,
+ image_flag_key=vlm_filter.image_flag_key,
+ num_workers=vlm_filter.num_workers,
+ )
+
+ # Only cache if no filtering was applied
+ # This prevents caching filtered datasets which may have different max_prompt_length
+ if cache_path and not should_filter:
dataset.save_to_disk(cache_path)
+ logger.info(f"Saved dataset to cache: {cache_path}")
+ elif cache_path and should_filter:
+ logger.info(
+ f"Skipping cache save because max_prompt_length={max_prompt_length} filtering was applied. "
+ "Next run will re-encode and re-filter with the same max_prompt_length."
+ )
return dataset
@@ -222,7 +356,12 @@ def __init__(self, pipeline_config: RLVRConfig):
self.tokenizer.padding_side = "left"
dataset = get_vlm_dataset(
- self.pipeline_config.actor_train.data_args, encode_function, self.processor, get_eval=False
+ self.pipeline_config.actor_train.data_args,
+ encode_function,
+ self.processor,
+ get_eval=False,
+ max_prompt_length=self.pipeline_config.prompt_length,
+ vlm_filter=self.pipeline_config.vlm_filter,
)
# update domain field, DynamicSamplingScheduler requires
dataset = dataset.map(
@@ -244,7 +383,12 @@ def __init__(self, pipeline_config: RLVRConfig):
self.val_dataset = None
if self.pipeline_config.validation and self.pipeline_config.validation.data_args:
self.val_dataset = get_vlm_dataset(
- self.pipeline_config.validation.data_args, encode_function, self.processor, get_eval=True
+ self.pipeline_config.validation.data_args,
+ encode_function,
+ self.processor,
+ get_eval=True,
+ max_prompt_length=self.pipeline_config.prompt_length,
+ vlm_filter=self.pipeline_config.vlm_filter,
)
self.val_dataset = self.val_dataset.map(
partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
@@ -316,9 +460,11 @@ def __init__(self, pipeline_config: RLVRConfig):
else:
domain_batch_size = int(domain_ratios[domain] * self.pipeline_config.rollout_batch_size)
accumulated += domain_batch_size
- generate_scheduler = ray.remote(DynamicSamplingScheduler).options(
- scheduling_strategy=NodeAffinitySchedulingStrategy(
- node_id=ray.get_runtime_context().get_node_id(), soft=False
+ generate_scheduler = (
+ ray.remote(DynamicSamplingScheduler).options(
+ scheduling_strategy=NodeAffinitySchedulingStrategy(
+ node_id=ray.get_runtime_context().get_node_id(), soft=False
+ )
)
).remote(
pipeline_config=self.pipeline_config,
@@ -353,9 +499,11 @@ def __init__(self, pipeline_config: RLVRConfig):
if self.val_dataset:
val_pipeline_config = copy.deepcopy(self.pipeline_config)
val_pipeline_config.is_use_additional_prompts = False
- self.val_generate_scheduler = ray.remote(DynamicSamplingScheduler).options(
- scheduling_strategy=NodeAffinitySchedulingStrategy(
- node_id=ray.get_runtime_context().get_node_id(), soft=False
+ self.val_generate_scheduler = (
+ ray.remote(DynamicSamplingScheduler).options(
+ scheduling_strategy=NodeAffinitySchedulingStrategy(
+ node_id=ray.get_runtime_context().get_node_id(), soft=False
+ )
)
).remote(
pipeline_config=val_pipeline_config,
@@ -494,9 +642,11 @@ def run(self):
metrics_mgr.add_metric("time/val_step", val_step_timer.last)
# 要按domain group by生成对应的batch
- with actor_infer_timer, actor_infer_response_timer, Timer(
- name="step_generate", logger=None
- ) as step_generate_timer:
+ with (
+ actor_infer_timer,
+ actor_infer_response_timer,
+ Timer(name="step_generate", logger=None) as step_generate_timer,
+ ):
domain_batches = {}
scheduler_refs = {}
for domain, scheduler in self.generate_schedulers.items():
@@ -527,7 +677,9 @@ def run(self):
batch.meta_info["_broadcast_non_tensor_batch"] = True
batch.meta_info["loss_mask_keys"] = ["response_mask", "final_response_mask"]
- batch.non_tensor_batch['sample_uuid'] = np.array([str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object)
+ batch.non_tensor_batch["sample_uuid"] = np.array(
+ [str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object
+ )
with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer:
if self.pipeline_config.enable_reference:
ref_log_probs = self.reference.compute_log_probs(batch, blocking=True)
@@ -542,7 +694,9 @@ def run(self):
values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False)
if self.pipeline_config.enable_old_logprobs_recompute:
- old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False)
+ old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(
+ batch, blocking=False
+ )
old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs)
agg_entropy = agg_loss(
loss_mat=old_log_probs.batch["entropy"],
@@ -723,7 +877,9 @@ def val(self, global_step):
{"global_step": self.global_step, "max_steps": self.pipeline_config.max_steps, "is_training": False}
)
generate_output: DataProto = ray.get(
- self.val_generate_scheduler.get_batch.remote(data=batch, global_step=global_step, batch_size=len(self.val_dataset)),
+ self.val_generate_scheduler.get_batch.remote(
+ data=batch, global_step=global_step, batch_size=len(self.val_dataset)
+ ),
timeout=self.pipeline_config.rpc_timeout,
)
generate_output.meta_info.pop("is_offload_states", None)
diff --git a/tests/pipeline/test_vlm_filter.py b/tests/pipeline/test_vlm_filter.py
new file mode 100644
index 000000000..7091a270c
--- /dev/null
+++ b/tests/pipeline/test_vlm_filter.py
@@ -0,0 +1,274 @@
+"""Unit tests for VLM filter_overlong_prompts function."""
+
+import pytest
+from unittest.mock import MagicMock, patch
+import datasets
+
+
+class TestFilterOverlongPrompts:
+ """Tests for filter_overlong_prompts function."""
+
+ @pytest.fixture
+ def mock_processor(self):
+ """Create a mock processor for testing."""
+ processor = MagicMock()
+ processor.apply_chat_template = MagicMock(return_value="formatted prompt")
+ processor.return_value = {"input_ids": [[1, 2, 3, 4, 5]]}
+ return processor
+
+ @pytest.fixture
+ def sample_dataset(self):
+ """Create a sample dataset for testing."""
+ return datasets.Dataset.from_dict({
+ "prompt": ["short prompt", "medium length prompt here", "this is a very long prompt that exceeds the limit"],
+ "images": [[None], [None], [None]],
+ "image_flag": [False, False, False],
+ "tag": ["test", "test", "test"],
+ })
+
+ def test_filter_empty_dataset(self, mock_processor):
+ """Test that empty dataset is handled correctly."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+ empty_dataset = datasets.Dataset.from_dict({
+ "prompt": [],
+ "images": [],
+ "image_flag": [],
+ })
+
+ result = filter_overlong_prompts(
+ dataset=empty_dataset,
+ processor=mock_processor,
+ max_prompt_length=10,
+ )
+
+ assert len(result) == 0
+
+ def test_filter_keeps_short_prompts(self, mock_processor, sample_dataset):
+ """Test that prompts within limit are kept."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+ # All prompts return 5 tokens, max is 10, so all should pass
+ result = filter_overlong_prompts(
+ dataset=sample_dataset,
+ processor=mock_processor,
+ max_prompt_length=10,
+ num_workers=1,
+ )
+
+ assert len(result) == 3
+
+ def test_filter_removes_long_prompts(self, mock_processor, sample_dataset):
+ """Test that prompts exceeding limit are filtered out."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+ # All prompts return 5 tokens, max is 3, so all should be filtered
+ result = filter_overlong_prompts(
+ dataset=sample_dataset,
+ processor=mock_processor,
+ max_prompt_length=3,
+ num_workers=1,
+ )
+
+ assert len(result) == 0
+
+ def test_filter_handles_parse_errors(self, mock_processor, sample_dataset):
+ """Test that parse errors result in filtering."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+ # Make processor raise an exception for some samples
+ call_count = [0]
+
+ def side_effect(*args, **kwargs):
+ call_count[0] += 1
+ if call_count[0] == 2:
+ raise ValueError("Parse error")
+ return {"input_ids": [[1, 2, 3]]}
+
+        mock_processor.side_effect = side_effect
+
+ result = filter_overlong_prompts(
+ dataset=sample_dataset,
+ processor=mock_processor,
+ max_prompt_length=10,
+ num_workers=1,
+ )
+
+ # Second sample should be filtered due to error, 2 remain
+ assert len(result) == 2
+
+ def test_filter_with_custom_keys(self, mock_processor):
+ """Test that custom key names are respected."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+ dataset = datasets.Dataset.from_dict({
+ "custom_prompt": ["test prompt"],
+ "custom_images": [[None]],
+ "custom_image_flag": [False],
+ })
+
+ result = filter_overlong_prompts(
+ dataset=dataset,
+ processor=mock_processor,
+ max_prompt_length=10,
+ prompt_key="custom_prompt",
+ image_key="custom_images",
+ image_flag_key="custom_image_flag",
+ num_workers=1,
+ )
+
+ assert len(result) == 1
+
+ def test_filter_with_valid_images(self, mock_processor):
+ """Test that samples with valid images are processed correctly."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+        mock_image = "fake_image"  # Arrow-serializable stand-in; MagicMock cannot be stored in a datasets.Dataset
+ dataset = datasets.Dataset.from_dict({
+ "prompt": ["prompt with image"],
+ "images": [[mock_image]],
+ "image_flag": [True],
+ })
+
+ result = filter_overlong_prompts(
+ dataset=dataset,
+ processor=mock_processor,
+ max_prompt_length=10,
+ num_workers=1,
+ )
+
+ # Verify processor was called with images
+ mock_processor.assert_called()
+ assert len(result) == 1
+
+ def test_filter_boundary_condition(self, mock_processor):
+ """Test that prompts exactly at max_prompt_length are kept."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+ dataset = datasets.Dataset.from_dict({
+ "prompt": ["exact length"],
+ "images": [[None]],
+ "image_flag": [False],
+ })
+
+ # Token length is 5, max is 5, should be kept (<=)
+ result = filter_overlong_prompts(
+ dataset=dataset,
+ processor=mock_processor,
+ max_prompt_length=5,
+ num_workers=1,
+ )
+
+ assert len(result) == 1
+
+ def test_filter_single_image_not_in_list(self, mock_processor):
+ """Test handling of single image not wrapped in list."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts
+
+        mock_image = "fake_image"  # Arrow-serializable stand-in; MagicMock cannot be stored in a datasets.Dataset
+ dataset = datasets.Dataset.from_dict({
+ "prompt": ["prompt"],
+ "images": [mock_image], # Single image, not in list
+ "image_flag": [True],
+ })
+
+ result = filter_overlong_prompts(
+ dataset=dataset,
+ processor=mock_processor,
+ max_prompt_length=10,
+ num_workers=1,
+ )
+
+ assert len(result) == 1
+
+
+class TestVLMFilterConfig:
+ """Tests for VLMFilterConfig dataclass."""
+
+ def test_default_values(self):
+ """Test that default values are correct."""
+ from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig
+
+ config = VLMFilterConfig()
+
+ assert config.enable is True
+ assert config.num_workers is None
+ assert config.prompt_key == "prompt"
+ assert config.image_key == "images"
+ assert config.image_flag_key == "image_flag"
+
+ def test_custom_values(self):
+ """Test that custom values can be set."""
+ from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig
+
+ config = VLMFilterConfig(
+ enable=False,
+ num_workers=4,
+ prompt_key="custom_prompt",
+ image_key="custom_images",
+ image_flag_key="custom_flag",
+ )
+
+ assert config.enable is False
+ assert config.num_workers == 4
+ assert config.prompt_key == "custom_prompt"
+ assert config.image_key == "custom_images"
+ assert config.image_flag_key == "custom_flag"
+
+
+class TestGetVLMDataset:
+ """Tests for get_vlm_dataset function integration with VLMFilterConfig."""
+
+ @pytest.fixture
+ def mock_data_args(self):
+ """Create mock data args."""
+ data_args = MagicMock()
+ data_args.cache_path = None
+ data_args.preprocessing_num_workers = 1
+ return data_args
+
+ def test_filter_disabled(self, mock_data_args):
+ """Test that filtering can be disabled via config."""
+ from roll.pipeline.rlvr.rlvr_vlm_pipeline import get_vlm_dataset
+ from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig
+
+ vlm_filter = VLMFilterConfig(enable=False)
+
+ # Mock the dependencies
+ with patch('roll.pipeline.rlvr.rlvr_vlm_pipeline.get_dataset') as mock_get_dataset:
+ mock_get_dataset.return_value = datasets.Dataset.from_dict({
+ "prompt": ["test"],
+ "images": [[None]],
+ "reward_model": [{"ground_truth": "answer"}],
+ "data_source": ["test"],
+ })
+
+ with patch('roll.pipeline.rlvr.rlvr_vlm_pipeline.encode_function') as mock_encode:
+ mock_encode.return_value = {
+ "tag": "test",
+ "images": [[None]],
+ "prompt": ["test"],
+ "ground_truth": ["answer"],
+ "reward_model": [{"ground_truth": "answer"}],
+ "image_flag": [False],
+ }
+
+ mock_processor = MagicMock()
+ mock_processor.tokenizer = MagicMock()
+ mock_processor.tokenizer.pad_token = ""
+
+ result = get_vlm_dataset(
+ data_args=mock_data_args,
+ encode_function=lambda *args, **kwargs: mock_encode.return_value,
+ processor=mock_processor,
+ get_eval=False,
+ max_prompt_length=5,
+ vlm_filter=vlm_filter,
+ )
+
+ # When disabled, dataset should not be filtered
+ assert result is not None
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])