diff --git a/tico/quantization/evaluation/mmmu_eval_utils.py b/tico/quantization/evaluation/mmmu_eval_utils.py
index 512f6a61..29bfc2dc 100644
--- a/tico/quantization/evaluation/mmmu_eval_utils.py
+++ b/tico/quantization/evaluation/mmmu_eval_utils.py
@@ -19,49 +19,64 @@
 import torch
 from datasets import load_dataset
 
-from tico.quantization.evaluation.vlm_eval_utils import generate_answer
-
-
-MMMU_DATASET = "MMMU/MMMU"
-
-MMMU_SUBJECTS: list[str] = [
-    "Accounting",
-    "Agriculture",
-    "Architecture_and_Engineering",
-    "Art",
-    "Art_Theory",
-    "Basic_Medical_Science",
-    "Biology",
-    "Chemistry",
-    "Clinical_Medicine",
-    "Computer_Science",
-    "Design",
-    "Diagnostics_and_Laboratory_Medicine",
-    "Economics",
-    "Electronics",
-    "Energy_and_Power",
-    "Finance",
-    "Geography",
-    "History",
-    "Literature",
-    "Manage",
-    "Marketing",
-    "Materials",
-    "Math",
-    "Mechanical_Engineering",
-    "Music",
-    "Pharmacy",
-    "Physics",
-    "Psychology",
-    "Public_Health",
-    "Sociology",
-]
-
-MMMU_SPLITS: list[str] = [
-    "dev",
-    "validation",
-    "test",
-]
+from tico.quantization.evaluation.vlm_eval_utils import (
+    generate_answer,
+    generate_image_only_answer,
+)
+
+
+MMMU_DATASETS = ["MMMU/MMMU", "MMMU/MMMU_Pro"]
+
+MMMU_SUBJECTS: dict[str, list[str]] = {
+    "MMMU/MMMU": [
+        "Accounting",
+        "Agriculture",
+        "Architecture_and_Engineering",
+        "Art",
+        "Art_Theory",
+        "Basic_Medical_Science",
+        "Biology",
+        "Chemistry",
+        "Clinical_Medicine",
+        "Computer_Science",
+        "Design",
+        "Diagnostics_and_Laboratory_Medicine",
+        "Economics",
+        "Electronics",
+        "Energy_and_Power",
+        "Finance",
+        "Geography",
+        "History",
+        "Literature",
+        "Manage",
+        "Marketing",
+        "Materials",
+        "Math",
+        "Mechanical_Engineering",
+        "Music",
+        "Pharmacy",
+        "Physics",
+        "Psychology",
+        "Public_Health",
+        "Sociology",
+    ],
+    "MMMU/MMMU_Pro": [
+        "standard (10 options)",
+        "standard (4 options)",
+        "vision",
+    ],
+}
+
+MMMU_SPLITS: dict[str, list[str]] = {
+    "MMMU/MMMU": [
+        "dev",
+        "validation",
+        "test",
+    ],
+    "MMMU/MMMU_Pro": [
+        "test",
+    ],
+}
 
 
 def take_from_dataset(ds, start: int, n: int) -> Iterable[dict[str, Any]]:
@@ -76,20 +91,25 @@ def take_from_dataset(ds, start: int, n: int) -> Iterable[dict[str, Any]]:
 
 
 def load_data(
+    dataset: str,
     subject: str,
-    split: str = "validation",
+    split: str,
     start: int = 0,
     n_samples: int = -1,
     streaming: bool = True,
 ) -> Iterable[dict[str, Any]]:
-    if subject not in MMMU_SUBJECTS:
+
+    if dataset not in MMMU_DATASETS:
+        raise ValueError(f"Invalid dataset '{dataset}'")
+
+    if subject not in MMMU_SUBJECTS[dataset]:
         raise ValueError(f"Invalid subject '{subject}'")
 
-    if split not in MMMU_SPLITS:
+    if split not in MMMU_SPLITS[dataset]:
         raise ValueError(f"Invalid split '{split}'")
 
     ds: Iterable[dict[str, Any]] = load_dataset(
-        path=MMMU_DATASET,
+        path=dataset,
         name=subject,
         split=split,
         streaming=streaming,
@@ -109,8 +129,8 @@ def get_item_mmmu(ex: dict[str, Any]) -> dict[str, Any]:
 
     return {
         "id": ex["id"],
-        "image": ex["image_1"],
-        "question": ex["question"],
+        "image": ex["image_1"] if "image_1" in ex else ex["image"],
+        "question": ex["question"] if "question" in ex else "",
         "choices": choices,
         "answer": ex["answer"],
     }
@@ -203,15 +223,30 @@ def extract_answer(generated_text: str) -> str | None:
     """
     text = generated_text.strip()
 
-    # Look for standalone letter [A-H] at the beginning, e.g. "A", "a", "A.", "a.", "A. Answer", "A Answer"
-    first_char_match = re.match(r"^([A-H])([.\s]+[^\s]+)?\.?$", text, re.IGNORECASE)
+    # Look for a letter at the beginning, e.g. "A", "A.", "(A)", "A Answer".
+    first_char_match = re.match(
+        r"^\s*\(?([A-J])\)?(?:[.)\s]|$)",
+        text,
+        re.IGNORECASE,
+    )
     if first_char_match:
         return first_char_match.group(1).upper()
 
+    # Common verbose outputs, e.g. "The answer is C", "Answer: C", "Option C".
+    answer_match = re.search(
+        r"\b(?:answer|option|choice)\s*(?:is|:)?\s*\(?([A-J])\)?\b",
+        text,
+        re.IGNORECASE,
+    )
+    if answer_match:
+        return answer_match.group(1).upper()
+
     return text
 
 
 def load_few_shot_examples(
+    dataset: str,
+    split: str,
     subject: str,
     n_shots: int = 5,
 ) -> list[dict[str, Any]]:
@@ -219,6 +254,8 @@ def load_few_shot_examples(
     Load few-shot examples for a given MMMU subject from the 'dev' split.
 
     Args:
+        dataset: Dataset name.
+        split: Split name (e.g. 'train', 'test', 'validation').
         subject: The subject name.
         n_shots: Number of few-shot examples to load.
 
@@ -229,8 +266,10 @@ def load_few_shot_examples(
         return []
 
     ds = load_data(
+        dataset=dataset,
         subject=subject,
-        split="dev",
+        split=split,
+        start=0,
         n_samples=n_shots,
         streaming=True,
     )
@@ -238,9 +277,16 @@ def load_few_shot_examples(
     return [get_item_mmmu(ex) for ex in ds]
 
 
+def is_mmmu_pro_vision(dataset: str, subject: str) -> bool:
+    return dataset == "MMMU/MMMU_Pro" and subject == "vision"
+
+
 def evaluate_subject(
     model,
     processor,
+    dataset: str,
+    eval_split: str,
+    few_shot_split: str,
     subject: str,
     device: str | torch.device,
     max_new_tokens: int,
@@ -255,7 +301,10 @@ def evaluate_subject(
 
     Args:
         model: Language model with generation capability.
-        tokenizer: Matching tokenizer for the model.
+        processor: Matching processor for the model.
+        dataset: Dataset name.
+        eval_split: Split name for evaluation (e.g. 'train', 'test', 'validation').
+        few_shot_split: Split name for few-shot examples (e.g. 'train', 'test', 'validation').
         subject: The MMMU subject to evaluate.
         device: Device for inference.
         n_shots: Number of few-shot examples.
@@ -267,11 +316,31 @@ def evaluate_subject(
     Returns:
         A tuple of (correct_count, total_count, skipped_count).
     """
-    few_shot_examples = load_few_shot_examples(subject=subject, n_shots=n_shots)
+    vision_only = is_mmmu_pro_vision(dataset, subject)
+    if vision_only:
+        if n_shots > 0 and verbose:
+            print(
+                "\n[WARNING] MMMU-Pro vision subset is evaluated image-only; "
+                f"ignoring n_shots={n_shots}."
+            )
+        few_shot_examples: list[dict[str, Any]] = []
+    else:
+        few_shot_examples = load_few_shot_examples(
+            dataset=dataset, split=few_shot_split, subject=subject, n_shots=n_shots
+        )
+
+    # If we take few-shot examples from the same split as evaluation examples,
+    # then exclude few-shot examples from the evaluation set by adjusting start argument to load_data.
+    if few_shot_examples and eval_split == few_shot_split:
+        start = n_shots
+    else:
+        start = 0
 
     test_data = load_data(
+        dataset=dataset,
         subject=subject,
-        split="validation",
+        split=eval_split,
+        start=start,
         n_samples=n_samples,
         streaming=True,
     )
@@ -283,7 +352,7 @@ def evaluate_subject(
     ex: dict[str, Any]
     for ex in test_data:
         # Skip questions with multiple images
-        if ex["image_2"] is not None:
+        if "image_2" in ex and ex["image_2"] is not None:
             skipped += 1
             if verbose:
                 question: str = ex["question"]
@@ -292,23 +361,59 @@ def evaluate_subject(
 
         item = get_item_mmmu(ex)
 
-        prompt = build_few_shot_prompt(
-            question=item["question"],
-            choices=item["choices"],
-            subject=subject,
-            few_shot_examples=few_shot_examples,
-        )
+        if vision_only:
+            prompt = ""
+        else:
+            prompt = build_few_shot_prompt(
+                question=item["question"],
+                choices=item["choices"],
+                subject=subject,
+                few_shot_examples=few_shot_examples,
+            )
 
-        generated = generate_answer(
-            model=model,
-            processor=processor,
-            question=prompt,
-            image=item["image"],
-            device=device,
-            max_new_tokens=max_new_tokens,
-            max_seq_len=max_seq_len,
-            temperature=temperature,
-        )
+        try:
+            if vision_only:
+                generated = generate_image_only_answer(
+                    model=model,
+                    processor=processor,
+                    image=item["image"],
+                    question="Answer the multiple-choice question shown in the image. Return only one letter from A to J.",
+                    device=device,
+                    max_new_tokens=max_new_tokens,
+                    max_seq_len=max_seq_len,
+                    temperature=temperature,
+                )
+            else:
+                generated = generate_answer(
+                    model=model,
+                    processor=processor,
+                    question=prompt,
+                    image=item["image"],
+                    device=device,
+                    max_new_tokens=max_new_tokens,
+                    max_seq_len=max_seq_len,
+                    temperature=temperature,
+                )
+        except ValueError as error:
+            if "Mismatch in `image` token count between text and `input_ids`." in str(
+                error
+            ):
+                if verbose:
+                    print(
+                        f"\n[WARNING] prompt too long for the specified max_seq_len={max_seq_len}. Skipping."
+                    )
+                    print(f"Error: {error}")
+                    print(f"Prompt: {prompt}")
+                skipped += 1
+                continue
+            else:
+                raise
+        except RuntimeError as error:
+            if verbose:
+                print(f"[ERROR]: {error}")
+                print(f"Prompt: {prompt}")
+            skipped += 1
+            continue
 
         predicted = extract_answer(generated)
         gold = item["answer"].upper()
@@ -319,7 +424,10 @@ def evaluate_subject(
 
         if verbose:
             print(f"\n[Sample {total}] Subject: {subject}")
-            print(f"Q: {item['question'][:100]}...")
+            if vision_only:
+                print("Q: ")
+            else:
+                print(f"Q: {item['question'][:100]}...")
             print(f"Choices: {item['choices']}")
             print(
                 f"Generated: {generated}, Predicted: {predicted}, Gold: {gold}, Correct: {is_correct}"
@@ -331,6 +439,7 @@ def evaluate_subject(
 def evaluate_mmmu(
     model,
     processor,
+    dataset: str,
     subjects: list[str] | None = None,
     device: str | torch.device = "cuda",
     n_shots: int = 5,
@@ -345,7 +454,8 @@ def evaluate_mmmu(
 
     Args:
         model: Language model with generation capability.
-        tokenizer: Matching tokenizer for the model.
+        processor: Matching processor for the model.
+        dataset: Dataset name.
         subjects: List of subjects to evaluate. Use None for all subjects.
         device: Device for inference.
         n_shots: Number of few-shot examples per subject.
@@ -357,8 +467,15 @@ def evaluate_mmmu(
     Returns:
         Aggregated results dictionary in '{ subject: (correct, total, skipped) }' format.
     """
+    if dataset not in MMMU_DATASETS:
+        raise ValueError(f"Invalid dataset '{dataset}'")
+
     if subjects is None:
-        subjects = MMMU_SUBJECTS
+        subjects = MMMU_SUBJECTS[dataset]
+
+    eval_split = "validation" if dataset == "MMMU/MMMU" else "test"
+    # MMMU withholds answers in 'test', so its few-shot examples must come from 'dev'.
+    few_shot_split = "dev" if dataset == "MMMU/MMMU" else "test"
 
     # { subject: (correct, total) }
     results: dict[str, tuple[int, int, int]] = {}
@@ -370,6 +487,9 @@ def evaluate_mmmu(
         correct, total, skipped = evaluate_subject(
             model=model,
             processor=processor,
+            dataset=dataset,
+            eval_split=eval_split,
+            few_shot_split=few_shot_split,
             subject=subject,
             device=device,
             n_shots=n_shots,
diff --git a/tico/quantization/evaluation/vlm_eval_utils.py b/tico/quantization/evaluation/vlm_eval_utils.py
index 18153aa8..f4aa18fc 100644
--- a/tico/quantization/evaluation/vlm_eval_utils.py
+++ b/tico/quantization/evaluation/vlm_eval_utils.py
@@ -370,6 +370,85 @@ def generate_answer(
     return processor.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
 
 
+@torch.no_grad()
+def generate_image_only_answer(
+    model,
+    processor,
+    image,
+    device: str | torch.device,
+    question: str | None = None,
+    max_new_tokens: int = 16,
+    temperature: float = 0.0,
+    max_seq_len: int | None = None,
+) -> str:
+    """
+    Generate an answer from the image only.
+
+    Args:
+        model: Vision-language generation model.
+        processor: Matching processor for the model.
+        image: Input image.
+        device: Device on which generation should run.
+        question: Optional text question.
+        max_new_tokens: Maximum number of generated tokens.
+        temperature: Sampling temperature. Greedy decoding is used when this
+            value is less than or equal to zero.
+        max_seq_len: Optional maximum text sequence length for processor
+            tokenization.
+
+    Returns:
+        The decoded model answer string.
+    """
+    content: list = [{"type": "image"}]
+
+    if question is not None:
+        content.append(
+            {
+                "type": "text",
+                "text": question,
+            }
+        )
+
+    messages = [
+        {
+            "role": "user",
+            "content": content,
+        }
+    ]
+
+    prompt = processor.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    processor_kwargs: dict[str, Any] = {
+        "text": prompt,
+        "images": image,
+        "return_tensors": "pt",
+    }
+    if max_seq_len is not None and max_seq_len > 0:
+        processor_kwargs["truncation"] = True
+        processor_kwargs["max_length"] = max_seq_len
+
+    inputs = processor(**processor_kwargs)
+    inputs = move_inputs_to_device(inputs, device)
+
+    do_sample = temperature > 0.0
+    gen_kwargs: dict[str, Any] = {
+        "max_new_tokens": max_new_tokens,
+        "do_sample": do_sample,
+    }
+    if do_sample:
+        gen_kwargs["temperature"] = temperature
+
+    out_ids = model.generate(**inputs, **gen_kwargs)
+    input_len = inputs["input_ids"].shape[1]
+    gen_ids = out_ids[0, input_len:]
+
+    return processor.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+
+
 class CocoResult(TypedDict):
     image_id: str
     caption: str
diff --git a/tico/quantization/wrapq/examples/quantize_qwen3_vl_with_gptq.py b/tico/quantization/wrapq/examples/quantize_qwen3_vl_with_gptq.py
index b3bb98a6..11072046 100644
--- a/tico/quantization/wrapq/examples/quantize_qwen3_vl_with_gptq.py
+++ b/tico/quantization/wrapq/examples/quantize_qwen3_vl_with_gptq.py
@@ -38,6 +38,7 @@
 )
 from tico.quantization.evaluation.mmmu_eval_utils import (
     evaluate_mmmu,
+    MMMU_DATASETS,
     print_mmmu_results,
 )
 from tico.quantization.evaluation.vlm_eval_utils import (
@@ -396,15 +397,24 @@ def parse_args():
     )
 
     # MMMU evaluation arguments
+    parser.add_argument(
+        "--mmmu_dataset",
+        type=str,
+        choices=MMMU_DATASETS,
+        default=None,
+        help="MMMU dataset to evaluate. MMMU evaluation is skipped when omitted.",
+    )
+
     parser.add_argument(
         "--mmmu_subjects",
         type=str,
         default=None,
         nargs="+",
         help=(
-            "Space-separated list of MMMU subjects to evaluate. Use 'mmmu' for all subjects."
-            "Use 'Accounting', 'Agriculture', 'Art', etc. for specific subjects."
-            "See https://huggingface.co/datasets/MMMU/MMMU for the full list."
+            "Space-separated list of MMMU subjects to evaluate. "
+            "Use 'Accounting', 'Agriculture', 'Art', etc. for specific subjects. "
+            "See https://huggingface.co/datasets/MMMU/MMMU for the full list. "
+            "If not specified, all subjects will be evaluated."
         ),
     )
     parser.add_argument(
@@ -1194,11 +1204,12 @@ def evaluate_original_model(model, processor, args):
         acc = get_hellaswag_accuracy(original_hellaswag_results)
         print(f"Accuracy: {acc['acc']:.4f}, Accuracy (norm): {acc['acc_norm']:.4f}")
 
-    if args.mmmu_subjects is not None:
+    if args.mmmu_dataset is not None:
         print("\n=== MMMU Evaluation (Original Model) ===")
         original_mmmu_results = evaluate_mmmu(
             model=model,
             processor=processor,
+            dataset=args.mmmu_dataset,
             subjects=args.mmmu_subjects,
             device=args.device,
             n_shots=args.mmmu_n_shots,
@@ -1287,11 +1298,12 @@ def evaluate_quantized_model(model, processor, args, original_results=None) -> N
         acc = get_hellaswag_accuracy(quantized_hellaswag_results)
         print(f"Accuracy: {acc['acc']:.4f}, Accuracy (norm): {acc['acc_norm']:.4f}")
 
-    if args.mmmu_subjects is not None:
+    if args.mmmu_dataset is not None:
         print("\n=== MMMU Evaluation (Quantized Model) ===")
         quantized_mmmu_results = evaluate_mmmu(
             model=model,
             processor=processor,
+            dataset=args.mmmu_dataset,
             subjects=args.mmmu_subjects,
             device=args.device,
             n_shots=args.mmmu_n_shots,