From ba185851cf65db56163d33f94591ebe500841fb6 Mon Sep 17 00:00:00 2001
From: "wangjiaju.716"
Date: Mon, 25 Aug 2025 19:59:25 +0800
Subject: [PATCH] Add builtin tools: image generate, image edit, video generate

---
 veadk/tools/builtin_tools/image_edit.py     | 107 ++++++++++++++
 veadk/tools/builtin_tools/image_generate.py | 108 ++++++++++++++
 veadk/tools/builtin_tools/video_generate.py | 154 ++++++++++++++++++++
 3 files changed, 369 insertions(+)
 create mode 100644 veadk/tools/builtin_tools/image_edit.py
 create mode 100644 veadk/tools/builtin_tools/image_generate.py
 create mode 100644 veadk/tools/builtin_tools/video_generate.py

diff --git a/veadk/tools/builtin_tools/image_edit.py b/veadk/tools/builtin_tools/image_edit.py
new file mode 100644
index 00000000..92d52f9a
--- /dev/null
+++ b/veadk/tools/builtin_tools/image_edit.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+from google.adk.tools import ToolContext
+from google.genai import types
+from volcenginesdkarkruntime import Ark
+from veadk.config import getenv
+import base64
+
+from veadk.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+client = Ark(
+    api_key=getenv("MODEL_IMAGE_API_KEY"),
+    base_url=getenv("MODEL_IMAGE_API_BASE"),
+)
+
+
+async def image_edit(
+    origin_image: str,
+    image_name: str,
+    image_prompt: str,
+    response_format: str,
+    guidance_scale: float,
+    watermark: bool,
+    seed: int,
+    tool_context: ToolContext,
+) -> Dict:
+    """Edit an image according to the prompt.
+
+    Args:
+        origin_image: The url or the base64 string of the image to edit.
+        image_name: The name of the generated image.
+        image_prompt: The prompt that describes the image.
+        response_format: str, b64_json or url, default url.
+        guidance_scale: default 2.5.
+        watermark: default True.
+        seed: default -1.
+
+    Returns:
+        A dict with "status" set to "success" (plus "image_name" and "image")
+        or to "error" (plus a "message" describing the failure).
+    """
+    if response_format not in ("url", "b64_json"):
+        # Fail fast: any other value would leave no image to return below.
+        return {
+            "status": "error",
+            "message": f"Unsupported response_format: {response_format}",
+        }
+
+    try:
+        response = client.images.generate(
+            model=getenv("MODEL_EDIT_NAME"),
+            image=origin_image,
+            prompt=image_prompt,
+            response_format=response_format,
+            guidance_scale=guidance_scale,
+            watermark=watermark,
+            seed=seed,
+        )
+
+        if not response.data:
+            error_details = f"No images returned by Doubao model: {response}"
+            logger.error(error_details)
+            return {"status": "error", "message": error_details}
+
+        # Only the first returned image is used.
+        item = response.data[0]
+        if response_format == "url":
+            image = item.url
+            tool_context.state["generated_image_url"] = image
+        else:
+            image = item.b64_json
+            image_bytes = base64.b64decode(image)
+
+            tool_context.state["generated_image_url"] = (
+                f"data:image/jpeg;base64,{image}"
+            )
+
+            # NOTE(review): the data URI above says image/jpeg while the
+            # artifact is stored as image/png -- confirm the real format.
+            report_artifact = types.Part.from_bytes(
+                data=image_bytes, mime_type="image/png"
+            )
+            await tool_context.save_artifact(image_name, report_artifact)
+            logger.debug(f"Image saved as ADK artifact: {image_name}")
+
+        return {"status": "success", "image_name": image_name, "image": image}
+
+    except Exception as e:
+        logger.error(f"Doubao image generation failed: {e}")
+        return {
+            "status": "error",
+            "message": f"Doubao image generation failed: {str(e)}",
+        }
diff --git a/veadk/tools/builtin_tools/image_generate.py b/veadk/tools/builtin_tools/image_generate.py
new file mode 100644
index 00000000..b069078e
--- /dev/null
+++ b/veadk/tools/builtin_tools/image_generate.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+
+from google.genai import types
+from google.adk.tools import ToolContext
+from veadk.config import getenv
+import base64
+from volcenginesdkarkruntime import Ark
+
+from veadk.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+client = Ark(
+    api_key=getenv("MODEL_IMAGE_API_KEY"),
+    base_url=getenv("MODEL_IMAGE_API_BASE"),
+)
+
+
+async def image_generate(
+    image_name: str,
+    image_prompt: str,
+    response_format: str,
+    size: str,
+    guidance_scale: float,
+    watermark: bool,
+    seed: int,
+    tool_context: ToolContext,
+) -> Dict:
+    """Generate an image according to the prompt.
+
+    Args:
+        image_name: The name of the generated image.
+        image_prompt: The prompt that describes the image.
+        response_format: str, b64_json or url, default url.
+        size: default 1024x1024.
+        guidance_scale: default 2.5.
+        watermark: default True.
+        seed: default -1.
+
+    Returns:
+        A dict with "status" set to "success" (plus "image_name" and "image")
+        or to "error" (plus a "message" describing the failure).
+    """
+    if response_format not in ("url", "b64_json"):
+        # Fail fast: any other value would leave no image to return below.
+        return {
+            "status": "error",
+            "message": f"Unsupported response_format: {response_format}",
+        }
+
+    try:
+        response = client.images.generate(
+            model=getenv("MODEL_IMAGE_NAME"),
+            prompt=image_prompt,
+            response_format=response_format,
+            size=size,
+            guidance_scale=guidance_scale,
+            watermark=watermark,
+            seed=seed,
+        )
+
+        if not response.data:
+            error_details = f"No images returned by Doubao model: {response}"
+            logger.error(error_details)
+            return {"status": "error", "message": error_details}
+
+        # Only the first returned image is used.
+        item = response.data[0]
+        if response_format == "url":
+            image = item.url
+            tool_context.state["generated_image_url"] = image
+        else:
+            image = item.b64_json
+            image_bytes = base64.b64decode(image)
+
+            tool_context.state["generated_image_url"] = (
+                f"data:image/jpeg;base64,{image}"
+            )
+
+            # NOTE(review): the data URI above says image/jpeg while the
+            # artifact is stored as image/png -- confirm the real format.
+            report_artifact = types.Part.from_bytes(
+                data=image_bytes, mime_type="image/png"
+            )
+            await tool_context.save_artifact(image_name, report_artifact)
+            logger.debug(f"Image saved as ADK artifact: {image_name}")
+
+        return {"status": "success", "image_name": image_name, "image": image}
+
+    except Exception as e:
+        logger.error(f"Doubao image generation failed: {e}")
+        return {
+            "status": "error",
+            "message": f"Doubao image generation failed: {str(e)}",
+        }
diff --git a/veadk/tools/builtin_tools/video_generate.py b/veadk/tools/builtin_tools/video_generate.py
new file mode 100644
index 00000000..d3212f85
--- /dev/null
+++ b/veadk/tools/builtin_tools/video_generate.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+from google.adk.tools import ToolContext
+from volcenginesdkarkruntime import Ark
+from veadk.config import getenv
+import asyncio
+import time
+import traceback
+
+from veadk.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+client = Ark(
+    api_key=getenv("MODEL_VIDEO_API_KEY"),
+    base_url=getenv("MODEL_VIDEO_API_BASE"),
+)
+
+# Seconds to wait between two polling rounds of the pending tasks.
+POLL_INTERVAL_SECONDS = 10
+# Safety net so a task stuck in a non-terminal status cannot hang the tool.
+# NOTE(review): tune this to the real latency of the video model.
+POLL_TIMEOUT_SECONDS = 30 * 60
+
+
+async def generate(tool_context, prompt, first_frame_image=None, last_frame_image=None):
+    """Submit one video generation task and return the creation response.
+
+    Args:
+        tool_context: Unused; kept so existing callers keep working.
+        prompt: The text prompt of the video.
+        first_frame_image: Optional first frame, url or base64 string.
+        last_frame_image: Optional last frame, url or base64 string. It is
+            only sent together with a first frame.
+    """
+    content = [{"type": "text", "text": prompt}]
+    if first_frame_image is None:
+        logger.debug("text generation")
+    elif last_frame_image is None:
+        logger.debug("first frame generation")
+        content.append(
+            {"type": "image_url", "image_url": {"url": first_frame_image}}
+        )
+    else:
+        logger.debug("last frame generation")
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": first_frame_image},
+                "role": "first_frame",
+            }
+        )
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": last_frame_image},
+                "role": "last_frame",
+            }
+        )
+
+    try:
+        return client.content_generation.tasks.create(
+            model=getenv("MODEL_VIDEO_NAME"),
+            content=content,
+        )
+    except Exception:
+        traceback.print_exc()
+        raise
+
+
+async def video_generate(params: list, tool_context: ToolContext) -> Dict:
+    """Generate video in batch according to the prompt.
+
+    Args:
+        params: A list of dicts, one per video, with the keys:
+            video_name: The name of the generated video.
+            first_frame: The first frame of the video, url or base64 string, or None.
+            last_frame: The last frame of the video, url or base64 string, or None.
+            prompt: The prompt of the video.
+
+    Returns:
+        A dict with "status" ("success" if at least one video was generated,
+        otherwise "error") and a "message" listing the results.
+    """
+    batch_size = 10
+    success_list = []
+    error_list = []
+    for start_idx in range(0, len(params), batch_size):
+        batch = params[start_idx : start_idx + batch_size]
+        task_dict = {}
+        for item in batch:
+            video_name = item["video_name"]
+            prompt = item["prompt"]
+            # Frames are optional, so tolerate missing keys and empty values.
+            first_frame = item.get("first_frame") or None
+            last_frame = item.get("last_frame") or None
+            try:
+                response = await generate(
+                    tool_context, prompt, first_frame, last_frame
+                )
+                task_dict[response.id] = video_name
+            except Exception:
+                traceback.print_exc()
+                # A failed submission must show up in the final report.
+                error_list.append(video_name)
+
+        deadline = time.monotonic() + POLL_TIMEOUT_SECONDS
+        while task_dict:
+            for task_id in list(task_dict):
+                result = client.content_generation.tasks.get(task_id=task_id)
+                status = result.status
+                if status == "succeeded":
+                    logger.debug("----- task succeeded -----")
+                    video_name = task_dict.pop(task_id)
+                    tool_context.state[f"{video_name}_video_url"] = (
+                        result.content.video_url
+                    )
+                    success_list.append({video_name: result.content.video_url})
+                elif status == "failed":
+                    logger.debug("----- task failed -----")
+                    logger.debug(f"Error: {result.error}")
+                    error_list.append(task_dict.pop(task_id))
+                else:
+                    logger.debug(
+                        f"Current status: {status}, Retrying after "
+                        f"{POLL_INTERVAL_SECONDS} seconds..."
+                    )
+            if not task_dict:
+                break
+            if time.monotonic() >= deadline:
+                logger.error(f"Timed out waiting for video tasks: {task_dict}")
+                error_list.extend(task_dict.values())
+                break
+            # Sleep once per round, without blocking the event loop.
+            await asyncio.sleep(POLL_INTERVAL_SECONDS)
+
+    if not success_list:
+        return {"status": "error", "message": f"Following videos failed: {error_list}"}
+    return {
+        "status": "success",
+        "message": f"Following videos generated: {success_list}\nFollowing videos failed: {error_list}",
+    }