diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index 43b8ffccb..df2e91d28 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -17,6 +17,35 @@ class DatasetFilterConfig: max_difficulty: Optional[float] = None num_samples: int = 0 + +@dataclass +class VLMFilterConfig: + """Configuration for filtering overlong prompts in VLM pipeline. + + Similar to verl's RLHFDataset config options. + """ + enable: bool = field( + default=True, + metadata={"help": "Whether to filter out samples exceeding max_prompt_length."} + ) + num_workers: Optional[int] = field( + default=None, + metadata={"help": "Number of workers for parallel filtering. Defaults to max(1, cpu_count // 4)."} + ) + prompt_key: str = field( + default="prompt", + metadata={"help": "Key for the prompt text in the dataset."} + ) + image_key: str = field( + default="images", + metadata={"help": "Key for the images in the dataset."} + ) + image_flag_key: str = field( + default="image_flag", + metadata={"help": "Key for the image flag indicating valid images."} + ) + + @dataclass class RewardFilterConfig: type: Literal["no_filter", "mean_filter", "std_filter"] = field( @@ -88,6 +117,10 @@ class RLVRConfig(PPOConfig): default_factory=DatasetFilterConfig, metadata={"help": "Configuration for filtering dataset by source and difficulty"}, ) + vlm_filter: VLMFilterConfig = field( + default_factory=VLMFilterConfig, + metadata={"help": "Configuration for filtering overlong prompts in VLM pipeline"}, + ) num_return_sequences_in_group: int = field( default=1, metadata={"help": "The number of return sequences in one group, used in generation_args."} diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index dc7af456c..45e657558 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -51,7 +51,7 @@ def format_prompt(prompt, processor, use_image=True, 
prompt_image_token=None): - question_template = "{Question} Output the thinking process in and final answer (number) in tags." + question_template = "{Question}" if isinstance(prompt, list): messages = prompt else: @@ -162,12 +162,123 @@ def encode_function( return encodings -def get_vlm_dataset(data_args, encode_function, processor, get_eval=False): +def filter_overlong_prompts( + dataset: datasets.Dataset, + processor: ProcessorMixin, + max_prompt_length: int, + prompt_key: str = "prompt", + image_key: str = "images", + image_flag_key: str = "image_flag", + num_workers: Optional[int] = None, +) -> datasets.Dataset: + """Filter out samples where text + image tokens exceed max_prompt_length. + + Note: Returns a new filtered dataset; the original dataset is not modified. + + Args: + dataset: The dataset to filter (already encoded with formatted prompts and processed images). + processor: The processor to use for tokenization. + max_prompt_length: Maximum allowed token length (inclusive). + prompt_key: Key for the prompt text in the dataset. + image_key: Key for the images in the dataset. + image_flag_key: Key for the image flag indicating valid images. + num_workers: Number of processes for parallel filtering. Defaults to max(1, cpu_count // 4). + + Returns: + Filtered dataset with samples within the token length limit. + """ + # Default num_workers similar to verl's approach + if num_workers is None: + num_workers = max(1, (os.cpu_count() or 4) // 4) + + original_len = len(dataset) + if original_len == 0: + logger.info("Dataset is empty, skipping filtering") + return dataset + + def compute_token_count(example) -> int: + """Compute token count for a sample. 
Returns max_prompt_length + 1 on parse failure.""" + try: + # Prompt is already formatted by encode_function, use directly + prompt = example[prompt_key] + if not isinstance(prompt, str): + # Fallback: apply chat template only if prompt is not already a string + prompt = processor.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + + # Use already-processed images from encode_function + images = None + if example.get(image_flag_key) and example.get(image_key): + images = example[image_key] + # Handle single image or list of images + if not isinstance(images, list): + images = [images] + + # Calculate token length (same path as DataCollatorWithPaddingForMM) + inputs = processor(text=prompt, images=images) + input_ids = inputs["input_ids"][0] + + return len(input_ids) + + except Exception as e: + # Return max_prompt_length + 1 on any error to filter out the sample + logger.error(f"Error processing sample during filter, skipping: {e}") + return max_prompt_length + 1 + + filtered_dataset = dataset.filter( + lambda example: compute_token_count(example) <= max_prompt_length, + num_proc=num_workers, + desc=f"Keeping prompts with length <= {max_prompt_length} tokens", + ) + + filtered_len = len(filtered_dataset) + filtered_count = original_len - filtered_len + if filtered_count > 0: + logger.info( + f"Filtered {filtered_count}/{original_len} samples ({100 * filtered_count / original_len:.1f}%) " + f"exceeding max_prompt_length={max_prompt_length}. Remaining: {filtered_len}" + ) + else: + logger.info(f"All {original_len} samples within max_prompt_length={max_prompt_length}") + + return filtered_dataset + + +def get_vlm_dataset( + data_args, + encode_function, + processor, + get_eval=False, + max_prompt_length: Optional[int] = None, + vlm_filter: Optional["VLMFilterConfig"] = None, +): + """Load and encode VLM dataset with optional filtering. + + Args: + data_args: Data arguments containing dataset path and preprocessing settings. 
+ encode_function: Function to encode the dataset. + processor: Processor for tokenization and image processing. + get_eval: Whether to load evaluation dataset. + max_prompt_length: Maximum prompt length for filtering. If None, filtering is disabled. + vlm_filter: VLMFilterConfig for filtering settings. If None, uses default values. + """ + # Import here to avoid circular import + from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig + + if vlm_filter is None: + vlm_filter = VLMFilterConfig() + cache_path = getattr(data_args, "cache_path", None) if cache_path: cache_path = os.path.join(cache_path, "val" if get_eval else "train") - if cache_path and os.path.exists(cache_path): + + # When filtering is enabled with max_prompt_length, skip cache to ensure correct filtering. + # The cached dataset may have been filtered with a different max_prompt_length. + # Filtering is always done after loading/encoding, and we don't cache filtered results. + should_filter = vlm_filter.enable and max_prompt_length is not None + use_cache = cache_path and os.path.exists(cache_path) and not should_filter + if use_cache: dataset = load_from_disk(cache_path) + logger.info(f"Loaded dataset from cache: {cache_path}") return dataset dataset = get_dataset(data_args=data_args) @@ -202,8 +313,31 @@ def get_vlm_dataset(data_args, encode_function, processor, get_eval=False): desc="Encoding dataset", ) print(f"Encoding: {dataset}") - if cache_path: + + # Filter out samples where text + image tokens exceed max_prompt_length + # This is done AFTER encoding but BEFORE saving to cache + # Uses vlm_filter config for filtering settings + if should_filter: + dataset = filter_overlong_prompts( + dataset=dataset, + processor=processor, + max_prompt_length=max_prompt_length, + prompt_key=vlm_filter.prompt_key, + image_key=vlm_filter.image_key, + image_flag_key=vlm_filter.image_flag_key, + num_workers=vlm_filter.num_workers, + ) + + # Only cache if no filtering was applied + # This prevents caching 
filtered datasets which may have different max_prompt_length + if cache_path and not should_filter: dataset.save_to_disk(cache_path) + logger.info(f"Saved dataset to cache: {cache_path}") + elif cache_path and should_filter: + logger.info( + f"Skipping cache save because max_prompt_length={max_prompt_length} filtering was applied. " + "Next run will re-encode and re-filter with the same max_prompt_length." + ) return dataset @@ -222,7 +356,12 @@ def __init__(self, pipeline_config: RLVRConfig): self.tokenizer.padding_side = "left" dataset = get_vlm_dataset( - self.pipeline_config.actor_train.data_args, encode_function, self.processor, get_eval=False + self.pipeline_config.actor_train.data_args, + encode_function, + self.processor, + get_eval=False, + max_prompt_length=self.pipeline_config.prompt_length, + vlm_filter=self.pipeline_config.vlm_filter, ) # update domain field, DynamicSamplingScheduler requires dataset = dataset.map( @@ -244,7 +383,12 @@ def __init__(self, pipeline_config: RLVRConfig): self.val_dataset = None if self.pipeline_config.validation and self.pipeline_config.validation.data_args: self.val_dataset = get_vlm_dataset( - self.pipeline_config.validation.data_args, encode_function, self.processor, get_eval=True + self.pipeline_config.validation.data_args, + encode_function, + self.processor, + get_eval=True, + max_prompt_length=self.pipeline_config.prompt_length, + vlm_filter=self.pipeline_config.vlm_filter, ) self.val_dataset = self.val_dataset.map( partial(update_dataset_domain, self.pipeline_config.tag_2_domain), @@ -316,9 +460,11 @@ def __init__(self, pipeline_config: RLVRConfig): else: domain_batch_size = int(domain_ratios[domain] * self.pipeline_config.rollout_batch_size) accumulated += domain_batch_size - generate_scheduler = ray.remote(DynamicSamplingScheduler).options( - scheduling_strategy=NodeAffinitySchedulingStrategy( - node_id=ray.get_runtime_context().get_node_id(), soft=False + generate_scheduler = ( + 
ray.remote(DynamicSamplingScheduler).options( + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), soft=False + ) ) ).remote( pipeline_config=self.pipeline_config, @@ -353,9 +499,11 @@ def __init__(self, pipeline_config: RLVRConfig): if self.val_dataset: val_pipeline_config = copy.deepcopy(self.pipeline_config) val_pipeline_config.is_use_additional_prompts = False - self.val_generate_scheduler = ray.remote(DynamicSamplingScheduler).options( - scheduling_strategy=NodeAffinitySchedulingStrategy( - node_id=ray.get_runtime_context().get_node_id(), soft=False + self.val_generate_scheduler = ( + ray.remote(DynamicSamplingScheduler).options( + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), soft=False + ) ) ).remote( pipeline_config=val_pipeline_config, @@ -494,9 +642,11 @@ def run(self): metrics_mgr.add_metric("time/val_step", val_step_timer.last) # 要按domain group by生成对应的batch - with actor_infer_timer, actor_infer_response_timer, Timer( - name="step_generate", logger=None - ) as step_generate_timer: + with ( + actor_infer_timer, + actor_infer_response_timer, + Timer(name="step_generate", logger=None) as step_generate_timer, + ): domain_batches = {} scheduler_refs = {} for domain, scheduler in self.generate_schedulers.items(): @@ -527,7 +677,9 @@ def run(self): batch.meta_info["_broadcast_non_tensor_batch"] = True batch.meta_info["loss_mask_keys"] = ["response_mask", "final_response_mask"] - batch.non_tensor_batch['sample_uuid'] = np.array([str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object) + batch.non_tensor_batch["sample_uuid"] = np.array( + [str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object + ) with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: if self.pipeline_config.enable_reference: ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) @@ -542,7 +694,9 @@ def run(self): 
values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) if self.pipeline_config.enable_old_logprobs_recompute: - old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) + old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs( + batch, blocking=False + ) old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) agg_entropy = agg_loss( loss_mat=old_log_probs.batch["entropy"], @@ -723,7 +877,9 @@ def val(self, global_step): {"global_step": self.global_step, "max_steps": self.pipeline_config.max_steps, "is_training": False} ) generate_output: DataProto = ray.get( - self.val_generate_scheduler.get_batch.remote(data=batch, global_step=global_step, batch_size=len(self.val_dataset)), + self.val_generate_scheduler.get_batch.remote( + data=batch, global_step=global_step, batch_size=len(self.val_dataset) + ), timeout=self.pipeline_config.rpc_timeout, ) generate_output.meta_info.pop("is_offload_states", None) diff --git a/tests/pipeline/test_vlm_filter.py b/tests/pipeline/test_vlm_filter.py new file mode 100644 index 000000000..7091a270c --- /dev/null +++ b/tests/pipeline/test_vlm_filter.py @@ -0,0 +1,274 @@ +"""Unit tests for VLM filter_overlong_prompts function.""" + +import pytest +from unittest.mock import MagicMock, patch +import datasets + + +class TestFilterOverlongPrompts: + """Tests for filter_overlong_prompts function.""" + + @pytest.fixture + def mock_processor(self): + """Create a mock processor for testing.""" + processor = MagicMock() + processor.apply_chat_template = MagicMock(return_value="formatted prompt") + processor.return_value = {"input_ids": [[1, 2, 3, 4, 5]]} + return processor + + @pytest.fixture + def sample_dataset(self): + """Create a sample dataset for testing.""" + return datasets.Dataset.from_dict({ + "prompt": ["short prompt", "medium length prompt here", "this is a very long prompt that exceeds the limit"], + "images": [[None], 
[None], [None]], + "image_flag": [False, False, False], + "tag": ["test", "test", "test"], + }) + + def test_filter_empty_dataset(self, mock_processor): + """Test that empty dataset is handled correctly.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + empty_dataset = datasets.Dataset.from_dict({ + "prompt": [], + "images": [], + "image_flag": [], + }) + + result = filter_overlong_prompts( + dataset=empty_dataset, + processor=mock_processor, + max_prompt_length=10, + ) + + assert len(result) == 0 + + def test_filter_keeps_short_prompts(self, mock_processor, sample_dataset): + """Test that prompts within limit are kept.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + # All prompts return 5 tokens, max is 10, so all should pass + result = filter_overlong_prompts( + dataset=sample_dataset, + processor=mock_processor, + max_prompt_length=10, + num_workers=1, + ) + + assert len(result) == 3 + + def test_filter_removes_long_prompts(self, mock_processor, sample_dataset): + """Test that prompts exceeding limit are filtered out.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + # All prompts return 5 tokens, max is 3, so all should be filtered + result = filter_overlong_prompts( + dataset=sample_dataset, + processor=mock_processor, + max_prompt_length=3, + num_workers=1, + ) + + assert len(result) == 0 + + def test_filter_handles_parse_errors(self, mock_processor, sample_dataset): + """Test that parse errors result in filtering.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + # Make processor raise an exception for some samples + call_count = [0] + + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 2: + raise ValueError("Parse error") + return {"input_ids": [[1, 2, 3]]} + + mock_processor.side_effect = side_effect + + result = filter_overlong_prompts( + dataset=sample_dataset, + processor=mock_processor, +
max_prompt_length=10, + num_workers=1, + ) + + # Second sample should be filtered due to error, 2 remain + assert len(result) == 2 + + def test_filter_with_custom_keys(self, mock_processor): + """Test that custom key names are respected.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + dataset = datasets.Dataset.from_dict({ + "custom_prompt": ["test prompt"], + "custom_images": [[None]], + "custom_image_flag": [False], + }) + + result = filter_overlong_prompts( + dataset=dataset, + processor=mock_processor, + max_prompt_length=10, + prompt_key="custom_prompt", + image_key="custom_images", + image_flag_key="custom_image_flag", + num_workers=1, + ) + + assert len(result) == 1 + + def test_filter_with_valid_images(self, mock_processor): + """Test that samples with valid images are processed correctly.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + mock_image = MagicMock() + dataset = datasets.Dataset.from_dict({ + "prompt": ["prompt with image"], + "images": [[mock_image]], + "image_flag": [True], + }) + + result = filter_overlong_prompts( + dataset=dataset, + processor=mock_processor, + max_prompt_length=10, + num_workers=1, + ) + + # Verify processor was called with images + mock_processor.assert_called() + assert len(result) == 1 + + def test_filter_boundary_condition(self, mock_processor): + """Test that prompts exactly at max_prompt_length are kept.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + dataset = datasets.Dataset.from_dict({ + "prompt": ["exact length"], + "images": [[None]], + "image_flag": [False], + }) + + # Token length is 5, max is 5, should be kept (<=) + result = filter_overlong_prompts( + dataset=dataset, + processor=mock_processor, + max_prompt_length=5, + num_workers=1, + ) + + assert len(result) == 1 + + def test_filter_single_image_not_in_list(self, mock_processor): + """Test handling of single image not wrapped in list.""" + from 
roll.pipeline.rlvr.rlvr_vlm_pipeline import filter_overlong_prompts + + mock_image = MagicMock() + dataset = datasets.Dataset.from_dict({ + "prompt": ["prompt"], + "images": [mock_image], # Single image, not in list + "image_flag": [True], + }) + + result = filter_overlong_prompts( + dataset=dataset, + processor=mock_processor, + max_prompt_length=10, + num_workers=1, + ) + + assert len(result) == 1 + + +class TestVLMFilterConfig: + """Tests for VLMFilterConfig dataclass.""" + + def test_default_values(self): + """Test that default values are correct.""" + from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig + + config = VLMFilterConfig() + + assert config.enable is True + assert config.num_workers is None + assert config.prompt_key == "prompt" + assert config.image_key == "images" + assert config.image_flag_key == "image_flag" + + def test_custom_values(self): + """Test that custom values can be set.""" + from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig + + config = VLMFilterConfig( + enable=False, + num_workers=4, + prompt_key="custom_prompt", + image_key="custom_images", + image_flag_key="custom_flag", + ) + + assert config.enable is False + assert config.num_workers == 4 + assert config.prompt_key == "custom_prompt" + assert config.image_key == "custom_images" + assert config.image_flag_key == "custom_flag" + + +class TestGetVLMDataset: + """Tests for get_vlm_dataset function integration with VLMFilterConfig.""" + + @pytest.fixture + def mock_data_args(self): + """Create mock data args.""" + data_args = MagicMock() + data_args.cache_path = None + data_args.preprocessing_num_workers = 1 + return data_args + + def test_filter_disabled(self, mock_data_args): + """Test that filtering can be disabled via config.""" + from roll.pipeline.rlvr.rlvr_vlm_pipeline import get_vlm_dataset + from roll.pipeline.rlvr.rlvr_config import VLMFilterConfig + + vlm_filter = VLMFilterConfig(enable=False) + + # Mock the dependencies + with 
patch('roll.pipeline.rlvr.rlvr_vlm_pipeline.get_dataset') as mock_get_dataset: + mock_get_dataset.return_value = datasets.Dataset.from_dict({ + "prompt": ["test"], + "images": [[None]], + "reward_model": [{"ground_truth": "answer"}], + "data_source": ["test"], + }) + + with patch('roll.pipeline.rlvr.rlvr_vlm_pipeline.encode_function') as mock_encode: + mock_encode.return_value = { + "tag": "test", + "images": [[None]], + "prompt": ["test"], + "ground_truth": ["answer"], + "reward_model": [{"ground_truth": "answer"}], + "image_flag": [False], + } + + mock_processor = MagicMock() + mock_processor.tokenizer = MagicMock() + mock_processor.tokenizer.pad_token = "" + + result = get_vlm_dataset( + data_args=mock_data_args, + encode_function=lambda *args, **kwargs: mock_encode.return_value, + processor=mock_processor, + get_eval=False, + max_prompt_length=5, + vlm_filter=vlm_filter, + ) + + # When disabled, dataset should not be filtered + assert result is not None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])