Conversation

@bizhongan414
What does this PR do?

Bug Report: filter_overlong_prompts Fails for Multimodal Data

🐛 Problem Summary

When using filter_overlong_prompts=True with multimodal datasets (containing images/videos), the filtering mechanism fails to account for image/video tokens when calculating prompt lengths, causing:

  1. Training crashes with RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 2048 but got size 2597
  2. Validation failures due to inconsistent sequence lengths
  3. Silent data corruption where overlong multimodal samples incorrectly pass filtering

🔍 Root Cause Analysis

Two-Level Bug in verl/utils/dataset/rl_dataset.py

Level 1: Logic Error with pop() Side Effect

In _build_messages(), the code uses example.pop() to extract images/videos:

```python
def _build_messages(self, example: dict):
    messages: list = example[self.prompt_key]
    # ❌ pop() deletes the keys from the example dict
    images = example.pop(self.image_key, None) or []
    videos = example.pop(self.video_key, None) or []
    # ... process and embed images/videos into messages ...
    return messages
```
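
The side effect is easy to reproduce in isolation. A minimal, self-contained sketch (a toy dict stands in for a dataset row; the key names are illustrative):

```python
# Toy reproduction of the pop() side effect, independent of verl.
example = {
    "prompt": [{"role": "user", "content": "<image> Describe this figure."}],
    "images": ["<raw image bytes>"],
}

images = example.pop("images", None) or []  # extracts the images AND deletes the key

print(len(images))          # 1: the caller got the images...
print("images" in example)  # False: ...but any later check on the dict misses them
```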

Level 2: Broken Conditional Check

In maybe_filter_out_long_prompts(), the code checks for images after they've been deleted:

```python
def doc2len(doc) -> int:
    # Step 1: Call _build_messages, which pops image_key/video_key
    messages = self._build_messages(doc)

    raw_prompt = self.processor.apply_chat_template(messages, ...)

    # Step 2: Check whether image_key still exists in doc
    if image_key in doc and doc[image_key]:  # ❌ Always False!
        images = [process_image(...) for image in doc[image_key]]
    else:
        images = None  # ← Multimodal samples fall here!

    # Step 3: Calculate the length with images=None
    return len(processor(text=[raw_prompt], images=None, ...)["input_ids"][0])
    #                                       ^^^^^^^^^^^
    #                    Image tokens completely missing!
```

Result: Severe Length Underestimation

| Stage | Calculation | Result |
| --- | --- | --- |
| Filtering | Text only (images=None) | ~2040 tokens ✅ passes filter |
| Runtime (Agent Loop) | Text + images (correct) | ~2597 tokens ❌ exceeds limit |

📊 Impact Assessment

Affected Scenarios

  • ✅ All multimodal RLHF training (Qwen-VL, LLaVA, etc.)
  • ✅ Any dataset with filter_overlong_prompts=True
  • ✅ Both training and validation sets

Severity

  • 🔴 Critical: Training/validation crashes
  • 🔴 Data Integrity: Overlong samples contaminate dataset
  • 🔴 Silent Failure: No warning that filtering is broken

Example Numerical Impact

For a multimodal sample with:

  • Text: 100 tokens
  • Single high-res image: 2497 tokens

```python
# Original buggy code:
filtered_length = 100   # tokens; ✅ passes filter (< 2048)
actual_length = 2597    # tokens; ⚠️ crashes at runtime (> 2048)
```

🔧 Fix Implementation

Solution: Unified Vision Processing

Replace process_image() from verl.utils.dataset.vision_utils with process_vision_info() from qwen_vl_utils to match Agent Loop's behavior.

Key Changes in maybe_filter_out_long_prompts()

```python
def doc2len(doc) -> int:
    messages = self._build_messages(doc)

    # ✅ Use the same vision processing as Agent Loop
    from qwen_vl_utils import process_vision_info

    # ✅ Use the processor's actual patch_size when available
    actual_patch_size = (
        self.processor.image_processor.patch_size
        if hasattr(self.processor, "image_processor")
        else self.image_patch_size
    )

    # ✅ Extract images/videos from messages (not from doc!)
    images, videos = process_vision_info(
        messages,  # ← images are embedded here by _build_messages
        image_patch_size=actual_patch_size,
        return_video_metadata=True,
    )

    # ✅ Split videos and metadatas (same as Agent Loop)
    if videos:
        videos, video_metadatas = zip(*videos, strict=True)
        videos, video_metadatas = list(videos), list(video_metadatas)
    else:
        video_metadatas = None

    # apply_kwargs is built earlier in the method (elided here)
    raw_prompt = self.processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False, **apply_kwargs
    )

    # ✅ Calculate the length with the correct images/videos
    return len(
        self.processor(
            text=[raw_prompt],
            images=images,
            videos=videos,
            video_metadatas=video_metadatas,
            return_tensors="pt",
            do_sample_frames=False,
        )["input_ids"][0]
    )
```
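
As a usage sketch, doc2len then plugs into dataset filtering roughly as below (assuming the Hugging Face datasets filter API that the surrounding maybe_filter_out_long_prompts code builds on; the names and arguments here are illustrative, not a verbatim excerpt):

```python
# Illustrative wiring only: keep samples whose full (text + vision)
# token count fits within the configured prompt budget.
self.dataframe = self.dataframe.filter(
    lambda doc: doc2len(doc) <= self.max_prompt_length,
    desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
)
```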

Why This Fix Works

  1. process_vision_info(messages) extracts images/videos from the content of each message, not from doc[image_key] (see the sketch after this list)
  2. Bypasses the pop() issue because images are already embedded in messages
  3. Matches Agent Loop by using identical vision processing logic
  4. Ensures consistent parameters: image_patch_size, video_metadatas, return_tensors, do_sample_frames
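
To make point 1 concrete, here is a hypothetical example of the message shape this relies on; the content layout follows the Qwen-VL chat convention and is an assumption about what _build_messages produces, not a verbatim excerpt:

```python
# After _build_messages, the image lives inside the message content,
# so process_vision_info can find it even though pop() removed it from doc.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/sample.png"},  # embedded by _build_messages
            {"type": "text", "text": "What is shown in this figure?"},
        ],
    }
]
# process_vision_info(messages, ...) walks this structure, never doc[image_key].
```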

🧪 Verification

Test Setup

  • Dataset: MathVision (multimodal)
  • Config: filter_overlong_prompts=True, max_prompt_length=2048
  • Model: Qwen3-VL-4B-Instruct

Results

| Version | Filtering Behavior | Validation Result |
| --- | --- | --- |
| Old (buggy) | 0 samples filtered | RuntimeError: Expected size 2048 but got size 2597 |
| New (fixed) | Correctly filters overlong samples | ✅ No runtime errors |

Log Evidence

Before Fix:

```
[OLD VERSION] 📊 Filtering Results:
Original: 1000 samples
Filtered: 1000 samples
Removed:  0 samples (0.00%)
⚠️ Prompt length 2597 exceeds max_prompt_len 2048  # Runtime error
```

After Fix:

```
[NEW VERSION] 📊 Filtering Results:
Original: 1000 samples
Filtered: 985 samples
Removed:  15 samples (1.50%)
✅ All prompts within limit, no runtime errors
```

🎯 Benefits of This Fix

  1. Correctness: filter_overlong_prompts works as intended
  2. Stability: Eliminates validation/training crashes
  3. Consistency: Unifies vision processing across filtering and runtime
  4. Maintainability: Reduces code duplication

📝 Additional Notes

  • Backward Compatibility: This fix may filter out more samples than before; that is the intended behavior, since those samples were always overlong
  • Performance: No significant performance impact, as filtering is done once during dataset initialization
  • Related Issue: This also aligns image_patch_size calculation with the processor's actual configuration

Related Files:

  • verl/utils/dataset/rl_dataset.py (fixed)
  • verl/experimental/agent_loop/agent_loop.py (reference implementation)

@gemini-code-assist (bot) left a comment
Code Review

This pull request addresses a critical bug where filter_overlong_prompts failed for multimodal data, leading to crashes and data corruption. The detailed analysis in the description clearly identifies the root cause related to the misuse of pop() and inconsistent vision processing logic. The fix is well-implemented, refactoring the length calculation in maybe_filter_out_long_prompts to use a unified process_vision_info function that correctly extracts vision data from the messages object. This aligns the filtering logic with the runtime behavior in AgentLoop, ensuring consistent and accurate prompt length calculation. The changes are clean, targeted, and effectively resolve the reported issue. Overall, this is an excellent contribution that significantly improves the stability and correctness of multimodal data handling.
