From 7949ef43e5de92e922d6013cdcb5355460d16ebb Mon Sep 17 00:00:00 2001 From: Valentin Lobstein Date: Wed, 18 Mar 2026 21:12:28 +0100 Subject: [PATCH 1/2] Fix: Replace unsafe pickle deserialization with JSON-based serialization pickle.loads() on untrusted HTTP input allows arbitrary code execution (CWE-502). Replace all pickle serialization in reward_proxy.py and reward_server.py with a safe JSON + base64 format via safe_serialization.py. Ref: CVE-2026-25873 --- OmniGen2-RL/reward_server/reward_proxy.py | 14 ++--- OmniGen2-RL/reward_server/reward_server.py | 8 +-- .../reward_server/safe_serialization.py | 54 +++++++++++++++++++ 3 files changed, 65 insertions(+), 11 deletions(-) create mode 100644 OmniGen2-RL/reward_server/safe_serialization.py diff --git a/OmniGen2-RL/reward_server/reward_proxy.py b/OmniGen2-RL/reward_server/reward_proxy.py index 4e85c7d..554e53a 100644 --- a/OmniGen2-RL/reward_server/reward_proxy.py +++ b/OmniGen2-RL/reward_server/reward_proxy.py @@ -2,13 +2,13 @@ from typing import List, Dict, Any, Tuple import argparse -import pickle import requests import json import time import logging from flask import Flask, request, jsonify from concurrent.futures import ThreadPoolExecutor +from safe_serialization import safe_dumps, safe_loads from collections import defaultdict import math import yaml @@ -122,15 +122,15 @@ def _send_request_to_worker( try: response = requests.post( server_url, - data=pickle.dumps(batch_data), - headers={"Content-Type": "application/octet-stream"}, + data=safe_dumps(batch_data), + headers={"Content-Type": "application/json"}, timeout=600, # 300 seconds timeout ) response.raise_for_status() # Raise exception for 4xx or 5xx status codes - return pickle.loads(response.content) + return safe_loads(response.content) except requests.exceptions.RequestException as e: logger.error(f"Request to server {server_url} failed: {e}") - except pickle.PickleError as e: + except (json.JSONDecodeError, KeyError, TypeError) as e: logger.error(f"Failed to parse response from {server_url}: {e}") return None # Return None to indicate failure @@ -205,7 +205,7 @@ def process_batch( def prepare_request_data(request_body: bytes) -> Tuple[List, List, str, Dict]: """Parse request body and add original index to meta data.""" - data = pickle.loads(request_body) + data = safe_loads(request_body) input_images = data["input_images"] output_image = data["output_image"] meta_datas = data["meta_datas"] @@ -254,7 +254,7 @@ def evaluate(): f"Evaluation complete! Total time: {total_time:.3f}s ({total_time / original_batch_size * 1000:.1f} ms/image)" ) - return pickle.dumps(ordered_result) + return safe_dumps(ordered_result), 200, {"Content-Type": "application/json"} def main(): diff --git a/OmniGen2-RL/reward_server/reward_server.py b/OmniGen2-RL/reward_server/reward_server.py index efd45c5..295427e 100644 --- a/OmniGen2-RL/reward_server/reward_server.py +++ b/OmniGen2-RL/reward_server/reward_server.py @@ -4,7 +4,6 @@ from typing import List, Optional import argparse -import pickle import json import os import warnings @@ -16,6 +15,7 @@ from flask import Flask, request, jsonify from PIL import Image +from safe_serialization import safe_dumps, safe_loads from editscore import EditScore import yaml @@ -99,14 +99,14 @@ def vlm_worker(scorer: VLMScorer): "group_strict_reward": {_meta_data.get("tag", "vlm"): reward}, } ) - results[task_id] = pickle.dumps(result_payload) + results[task_id] = safe_dumps(result_payload) except Exception as e: print(f"❌ Worker thread error while processing task {task_id[:8]}: {e}") import traceback traceback.print_exc() error_result = {"error": f"Internal server error: {e}"} - results[task_id] = pickle.dumps(error_result) + results[task_id] = safe_dumps(error_result) finally: request_queue.task_done() @@ -115,7 +115,7 @@ def vlm_worker(scorer: VLMScorer): def parse_and_validate_request(raw_data: bytes) -> Tuple[List[Image.Image], Image.Image, Dict, str]: """Parse request data, validate and convert to required format.""" try: - data = pickle.loads(raw_data) + data = safe_loads(raw_data) input_images_datas = data['input_images'] output_image_datas = data['output_image'] meta_data = data['meta_data'] diff --git a/OmniGen2-RL/reward_server/safe_serialization.py b/OmniGen2-RL/reward_server/safe_serialization.py new file mode 100644 index 0000000..cd38fdc --- /dev/null +++ b/OmniGen2-RL/reward_server/safe_serialization.py @@ -0,0 +1,54 @@ +""" +Safe serialization utilities for reward server communication. + +Replaces pickle-based serialization with JSON + base64-encoded images +to prevent arbitrary code execution via deserialization attacks (CWE-502). +""" + +import io +import json +import base64 +from typing import Any + +from PIL import Image + +_IMAGE_MARKER = "__pil_image__" +_TUPLE_MARKER = "__tuple__" + + +def _encode(obj: Any) -> Any: + if isinstance(obj, Image.Image): + buf = io.BytesIO() + obj.save(buf, format="PNG") + return {_IMAGE_MARKER: base64.b64encode(buf.getvalue()).decode("ascii")} + if isinstance(obj, tuple): + return {_TUPLE_MARKER: [_encode(item) for item in obj]} + if isinstance(obj, list): + return [_encode(item) for item in obj] + if isinstance(obj, dict): + return {k: _encode(v) for k, v in obj.items()} + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + raise TypeError(f"Unsupported type for safe serialization: {type(obj)}") + + +def _decode(obj: Any) -> Any: + if isinstance(obj, dict): + if _IMAGE_MARKER in obj: + return Image.open(io.BytesIO(base64.b64decode(obj[_IMAGE_MARKER]))) + if _TUPLE_MARKER in obj: + return tuple(_decode(item) for item in obj[_TUPLE_MARKER]) + return {k: _decode(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_decode(item) for item in obj] + return obj + + +def safe_dumps(data: Any) -> bytes: + """Serialize data safely using JSON with base64-encoded images.""" + return json.dumps(_encode(data), separators=(",", ":")).encode("utf-8") + + +def safe_loads(data: bytes) -> Any: + """Deserialize data safely from JSON with base64-encoded images.""" + return _decode(json.loads(data)) From 58f9533c8d857ec1f90a2dc04b31ea78ef452bb1 Mon Sep 17 00:00:00 2001 From: Valentin Lobstein Date: Wed, 18 Mar 2026 21:18:10 +0100 Subject: [PATCH 2/2] Fix: Clean up docstring in safe_serialization module --- OmniGen2-RL/reward_server/safe_serialization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/OmniGen2-RL/reward_server/safe_serialization.py b/OmniGen2-RL/reward_server/safe_serialization.py index cd38fdc..058e4ba 100644 --- a/OmniGen2-RL/reward_server/safe_serialization.py +++ b/OmniGen2-RL/reward_server/safe_serialization.py @@ -1,8 +1,7 @@ """ Safe serialization utilities for reward server communication. -Replaces pickle-based serialization with JSON + base64-encoded images -to prevent arbitrary code execution via deserialization attacks (CWE-502). +Uses JSON with base64-encoded images for structured, portable data exchange. """ import io