Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions OmniGen2-RL/reward_server/reward_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from typing import List, Dict, Any, Tuple
import argparse
import pickle
import requests
import json
import time
import logging
from flask import Flask, request, jsonify
from concurrent.futures import ThreadPoolExecutor
from safe_serialization import safe_dumps, safe_loads
from collections import defaultdict
import math
import yaml
Expand Down Expand Up @@ -122,15 +122,15 @@ def _send_request_to_worker(
try:
response = requests.post(
server_url,
data=pickle.dumps(batch_data),
headers={"Content-Type": "application/octet-stream"},
data=safe_dumps(batch_data),
headers={"Content-Type": "application/json"},
timeout=600, # 300 seconds timeout
)
response.raise_for_status() # Raise exception for 4xx or 5xx status codes
return pickle.loads(response.content)
return safe_loads(response.content)
except requests.exceptions.RequestException as e:
logger.error(f"Request to server {server_url} failed: {e}")
except pickle.PickleError as e:
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"Failed to parse response from {server_url}: {e}")
return None # Return None to indicate failure

Expand Down Expand Up @@ -205,7 +205,7 @@ def process_batch(

def prepare_request_data(request_body: bytes) -> Tuple[List, List, str, Dict]:
"""Parse request body and add original index to meta data."""
data = pickle.loads(request_body)
data = safe_loads(request_body)
input_images = data["input_images"]
output_image = data["output_image"]
meta_datas = data["meta_datas"]
Expand Down Expand Up @@ -254,7 +254,7 @@ def evaluate():
f"Evaluation complete! Total time: {total_time:.3f}s ({total_time / original_batch_size * 1000:.1f} ms/image)"
)

return pickle.dumps(ordered_result)
return safe_dumps(ordered_result), 200, {"Content-Type": "application/json"}


def main():
Expand Down
8 changes: 4 additions & 4 deletions OmniGen2-RL/reward_server/reward_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing import List, Optional
import argparse
import pickle
import json
import os
import warnings
Expand All @@ -16,6 +15,7 @@

from flask import Flask, request, jsonify
from PIL import Image
from safe_serialization import safe_dumps, safe_loads

from editscore import EditScore
import yaml
Expand Down Expand Up @@ -99,14 +99,14 @@ def vlm_worker(scorer: VLMScorer):
"group_strict_reward": {_meta_data.get("tag", "vlm"): reward},
}
)
results[task_id] = pickle.dumps(result_payload)
results[task_id] = safe_dumps(result_payload)

except Exception as e:
print(f"❌ Worker thread error while processing task {task_id[:8]}: {e}")
import traceback
traceback.print_exc()
error_result = {"error": f"Internal server error: {e}"}
results[task_id] = pickle.dumps(error_result)
results[task_id] = safe_dumps(error_result)
finally:
request_queue.task_done()

Expand All @@ -115,7 +115,7 @@ def vlm_worker(scorer: VLMScorer):
def parse_and_validate_request(raw_data: bytes) -> Tuple[List[Image.Image], Image.Image, Dict, str]:
"""Parse request data, validate and convert to required format."""
try:
data = pickle.loads(raw_data)
data = safe_loads(raw_data)
input_images_datas = data['input_images']
output_image_datas = data['output_image']
meta_data = data['meta_data']
Expand Down
53 changes: 53 additions & 0 deletions OmniGen2-RL/reward_server/safe_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Safe serialization utilities for reward server communication.

Uses JSON with base64-encoded images for structured, portable data exchange.
"""

import io
import json
import base64
from typing import Any

from PIL import Image

_IMAGE_MARKER = "__pil_image__"
_TUPLE_MARKER = "__tuple__"


def _encode(obj: Any) -> Any:
if isinstance(obj, Image.Image):
buf = io.BytesIO()
obj.save(buf, format="PNG")
return {_IMAGE_MARKER: base64.b64encode(buf.getvalue()).decode("ascii")}
if isinstance(obj, tuple):
return {_TUPLE_MARKER: [_encode(item) for item in obj]}
if isinstance(obj, list):
return [_encode(item) for item in obj]
if isinstance(obj, dict):
return {k: _encode(v) for k, v in obj.items()}
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
raise TypeError(f"Unsupported type for safe serialization: {type(obj)}")


def _decode(obj: Any) -> Any:
if isinstance(obj, dict):
if _IMAGE_MARKER in obj:
return Image.open(io.BytesIO(base64.b64decode(obj[_IMAGE_MARKER])))
if _TUPLE_MARKER in obj:
return tuple(_decode(item) for item in obj[_TUPLE_MARKER])
return {k: _decode(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_decode(item) for item in obj]
return obj


def safe_dumps(data: Any) -> bytes:
"""Serialize data safely using JSON with base64-encoded images."""
return json.dumps(_encode(data), separators=(",", ":")).encode("utf-8")


def safe_loads(data: bytes) -> Any:
"""Deserialize data safely from JSON with base64-encoded images."""
return _decode(json.loads(data))