From 0108de3d9caa0d8ff68f9a627f1c8e87151f4aa1 Mon Sep 17 00:00:00 2001
From: "wangjiaju.716"
Date: Mon, 25 Aug 2025 18:50:12 +0800
Subject: [PATCH] Add buildin tools:image_generate,image_edit,video_generate

---
 veadk/tools/builtin_tools/image_edit.py     |  96 +++++++++++++
 veadk/tools/builtin_tools/image_generate.py |  97 +++++++++++++
 veadk/tools/builtin_tools/video_generate.py | 143 ++++++++++++++++++++
 3 files changed, 336 insertions(+)
 create mode 100644 veadk/tools/builtin_tools/image_edit.py
 create mode 100644 veadk/tools/builtin_tools/image_generate.py
 create mode 100644 veadk/tools/builtin_tools/video_generate.py

diff --git a/veadk/tools/builtin_tools/image_edit.py b/veadk/tools/builtin_tools/image_edit.py
new file mode 100644
index 00000000..e11eb454
--- /dev/null
+++ b/veadk/tools/builtin_tools/image_edit.py
@@ -0,0 +1,96 @@
+from typing import Dict
+from google.adk.tools import ToolContext
+from google.genai import types
+from volcenginesdkarkruntime import Ark
+from veadk.config import getenv
+import base64
+
+from veadk.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+client = Ark(
+    api_key=getenv("MODEL_IMAGE_API_KEY"),
+    base_url=getenv("MODEL_IMAGE_API_BASE"),
+)
+
+
+async def image_edit(
+    origin_image: str,
+    image_name: str,
+    image_prompt: str,
+    response_format: str,
+    guidance_scale: float,
+    watermark: bool,
+    seed: int,
+    tool_context: ToolContext,
+) -> Dict:
+    """Edit an image according to the prompt.
+
+    Args:
+        origin_image: The url or the base64 string of the image to edit.
+        image_name: The name of the generated image.
+        image_prompt: The prompt that describes the image.
+        response_format: str, b64_json or url, default url.
+        guidance_scale: default 2.5.
+        watermark: default True.
+        seed: default -1.
+
+    Returns:
+        A dict whose "status" is "success" or "error". On success it also holds
+        "image_name" and "image" (the url or base64 string of the first returned
+        image); on error it holds "message".
+    """
+    # Reject unknown formats before calling the model, so the caller gets a
+    # clear message instead of the text of an UnboundLocalError.
+    if response_format not in ("url", "b64_json"):
+        return {
+            "status": "error",
+            "message": f"Unsupported response_format: {response_format}",
+        }
+
+    try:
+        response = client.images.generate(
+            model=getenv("MODEL_EDIT_NAME"),
+            image=origin_image,
+            prompt=image_prompt,
+            response_format=response_format,
+            guidance_scale=guidance_scale,
+            watermark=watermark,
+            seed=seed,
+        )
+
+        if not response.data:
+            error_details = f"No images returned by Doubao model: {response}"
+            logger.error(error_details)
+            return {"status": "error", "message": error_details}
+
+        # Only the first returned image is reported.
+        item = response.data[0]
+        if response_format == "url":
+            image = item.url
+            tool_context.state["generated_image_url"] = image
+        else:
+            image = item.b64_json
+            image_bytes = base64.b64decode(image)
+
+            # NOTE(review): the data URL says image/jpeg while the artifact
+            # says image/png; confirm which format the model returns.
+            tool_context.state["generated_image_url"] = (
+                f"data:image/jpeg;base64,{image}"
+            )
+
+            report_artifact = types.Part.from_bytes(
+                data=image_bytes, mime_type="image/png"
+            )
+            await tool_context.save_artifact(image_name, report_artifact)
+            logger.debug(f"Image saved as ADK artifact: {image_name}")
+
+        return {"status": "success", "image_name": image_name, "image": image}
+
+    except Exception as e:
+        logger.error(f"Doubao image generation failed: {e}")
+        return {
+            "status": "error",
+            "message": f"Doubao image generation failed: {str(e)}",
+        }
diff --git a/veadk/tools/builtin_tools/image_generate.py b/veadk/tools/builtin_tools/image_generate.py
new file mode 100644
index 00000000..b76c208a
--- /dev/null
+++ b/veadk/tools/builtin_tools/image_generate.py
@@ -0,0 +1,97 @@
+from typing import Dict
+
+from google.genai import types
+from google.adk.tools import ToolContext
+from veadk.config import getenv
+import base64
+from volcenginesdkarkruntime import Ark
+
+from veadk.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+client = Ark(
+    api_key=getenv("MODEL_IMAGE_API_KEY"),
+    base_url=getenv("MODEL_IMAGE_API_BASE"),
+)
+
+
+async def image_generate(
+    image_name: str,
+    image_prompt: str,
+    response_format: str,
+    size: str,
+    guidance_scale: float,
+    watermark: bool,
+    seed: int,
+    tool_context: ToolContext,
+) -> Dict:
+    """Generate an image according to the prompt.
+
+    Args:
+        image_name: The name of the generated image.
+        image_prompt: The prompt that describes the image.
+        response_format: str, b64_json or url, default url.
+        size: default 1024x1024.
+        guidance_scale: default 2.5.
+        watermark: default True.
+        seed: default -1.
+
+    Returns:
+        A dict whose "status" is "success" or "error". On success it also holds
+        "image_name" and "image" (the url or base64 string of the first returned
+        image); on error it holds "message".
+    """
+    # Reject unknown formats before calling the model, so the caller gets a
+    # clear message instead of the text of an UnboundLocalError.
+    if response_format not in ("url", "b64_json"):
+        return {
+            "status": "error",
+            "message": f"Unsupported response_format: {response_format}",
+        }
+
+    try:
+        response = client.images.generate(
+            model=getenv("MODEL_IMAGE_NAME"),
+            prompt=image_prompt,
+            response_format=response_format,
+            size=size,
+            guidance_scale=guidance_scale,
+            watermark=watermark,
+            seed=seed,
+        )
+
+        if not response.data:
+            error_details = f"No images returned by Doubao model: {response}"
+            logger.error(error_details)
+            return {"status": "error", "message": error_details}
+
+        # Only the first returned image is reported.
+        item = response.data[0]
+        if response_format == "url":
+            image = item.url
+            tool_context.state["generated_image_url"] = image
+        else:
+            image = item.b64_json
+            image_bytes = base64.b64decode(image)
+
+            # NOTE(review): the data URL says image/jpeg while the artifact
+            # says image/png; confirm which format the model returns.
+            tool_context.state["generated_image_url"] = (
+                f"data:image/jpeg;base64,{image}"
+            )
+
+            report_artifact = types.Part.from_bytes(
+                data=image_bytes, mime_type="image/png"
+            )
+            await tool_context.save_artifact(image_name, report_artifact)
+            logger.debug(f"Image saved as ADK artifact: {image_name}")
+
+        return {"status": "success", "image_name": image_name, "image": image}
+
+    except Exception as e:
+        logger.error(f"Doubao image generation failed: {e}")
+        return {
+            "status": "error",
+            "message": f"Doubao image generation failed: {str(e)}",
+        }
diff --git a/veadk/tools/builtin_tools/video_generate.py b/veadk/tools/builtin_tools/video_generate.py
new file mode 100644
index 00000000..79a9b45e
--- /dev/null
+++ b/veadk/tools/builtin_tools/video_generate.py
@@ -0,0 +1,143 @@
+from typing import Dict
+from google.adk.tools import ToolContext
+from volcenginesdkarkruntime import Ark
+from veadk.config import getenv
+import asyncio
+import time
+import traceback
+import base64
+
+from veadk.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+client = Ark(
+    api_key=getenv("MODEL_VIDEO_API_KEY"),
+    base_url=getenv("MODEL_VIDEO_API_BASE"),
+)
+
+# Task statuses after which polling stops and the video is reported as failed.
+# NOTE(review): "cancelled" and "expired" are assumed from the Ark task status
+# list; confirm against the API reference.
+_FAILED_STATUSES = ("failed", "cancelled", "expired")
+
+
+async def generate(tool_context, prompt, first_frame_image=None, last_frame_image=None):
+    """Create one video generation task and return the creation response.
+
+    Args:
+        tool_context: Unused; kept so existing callers keep working.
+        prompt: The text prompt of the video.
+        first_frame_image: Optional url or base64 string of the first frame.
+        last_frame_image: Optional url or base64 string of the last frame. It is
+            only used when first_frame_image is given as well.
+
+    Returns:
+        The response of the task creation call; its id identifies the task.
+
+    Raises:
+        Exception: Whatever the task creation call raises, after the traceback
+            has been printed.
+    """
+    content = [{"type": "text", "text": prompt}]
+    if first_frame_image is None:
+        logger.debug("text generation")
+    elif last_frame_image is None:
+        logger.debug("first frame generation")
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": first_frame_image},
+            }
+        )
+    else:
+        logger.debug("last frame generation")
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": first_frame_image},
+                "role": "first_frame",
+            }
+        )
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": last_frame_image},
+                "role": "last_frame",
+            }
+        )
+
+    try:
+        return client.content_generation.tasks.create(
+            model=getenv("MODEL_VIDEO_NAME"),
+            content=content,
+        )
+    except Exception:
+        traceback.print_exc()
+        raise
+
+
+async def video_generate(params: list, tool_context: ToolContext) -> Dict:
+    """Generate video in batch according to the prompt.
+
+    Args:
+        params: A list of dicts, one per video, each with the keys:
+            video_name: The name of the generated video.
+            first_frame: The first frame of the video, url or base64 string, or None.
+            last_frame: The last frame of the video, url or base64 string, or None.
+            prompt: The prompt of the video.
+
+    Returns:
+        A dict with "status" and "message". The status is "error" when no video
+        was generated and "success" otherwise; the message lists the generated
+        and the failed videos.
+    """
+    batch_size = 10
+    success_list = []
+    error_list = []
+    for start_idx in range(0, len(params), batch_size):
+        batch = params[start_idx : start_idx + batch_size]
+        task_dict = {}
+        for item in batch:
+            video_name = item["video_name"]
+            prompt = item["prompt"]
+            # The frames are optional: a missing or empty value means "not given".
+            first_frame = item.get("first_frame") or None
+            last_frame = item.get("last_frame") or None
+            try:
+                response = await generate(
+                    tool_context, prompt, first_frame, last_frame
+                )
+                task_dict[response.id] = video_name
+            except Exception:
+                # Report a failed submission instead of silently dropping it.
+                traceback.print_exc()
+                error_list.append(video_name)
+
+        # Poll the submitted tasks of this batch until all of them are finished.
+        while task_dict:
+            for task_id in list(task_dict):
+                result = client.content_generation.tasks.get(task_id=task_id)
+                status = result.status
+                if status == "succeeded":
+                    logger.debug("----- task succeeded -----")
+                    video_url = result.content.video_url
+                    video_name = task_dict.pop(task_id)
+                    tool_context.state[f"{video_name}_video_url"] = video_url
+                    success_list.append({video_name: video_url})
+                elif status in _FAILED_STATUSES:
+                    logger.debug("----- task failed -----")
+                    logger.debug(f"Error: {result.error}")
+                    error_list.append(task_dict.pop(task_id))
+                else:
+                    logger.debug(f"Current status: {status}, Retrying after 10 seconds...")
+            if task_dict:
+                # Wait once per polling round without blocking the event loop.
+                await asyncio.sleep(10)
+
+    if not success_list:
+        return {"status": "error", "message": f"Following videos failed: {error_list}"}
+    return {
+        "status": "success",
+        "message": f"Following videos generated: {success_list}\nFollowing videos failed: {error_list}",
+    }