diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py
index 316efe18e7f..293b9821bac 100644
--- a/backend/open_webui/functions.py
+++ b/backend/open_webui/functions.py
@@ -3,6 +3,7 @@
 import inspect
 import json
 import asyncio
+import re
 
 from pydantic import BaseModel
 from typing import AsyncGenerator, Generator, Iterator
@@ -57,7 +58,98 @@
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+def validate_pipe_id(pipe_id: str) -> bool:
+    """
+    Check that pipe_id is a syntactically safe function identifier.
+
+    Only alphanumeric characters, hyphens, underscores, and dots are allowed,
+    and the ID must not start with a dot or contain "..". This is a format
+    check only: it does not look the function up or check authorization.
+    """
+    if not pipe_id or not isinstance(pipe_id, str):
+        return False
+
+    # fullmatch instead of match with "$": "$" also matches before a trailing
+    # newline, which would let "name\n" through.
+    if not re.fullmatch(r"[a-zA-Z0-9._-]+", pipe_id):
+        return False
+
+    # Prevent path traversal attempts ("/" is already excluded by the pattern)
+    if ".." in pipe_id or pipe_id.startswith("."):
+        return False
+
+    return True
+
+
+def validate_pipe_params(params: dict) -> dict:
+    """
+    Return a copy of params restricted to the keys a pipe may receive.
+
+    Keys outside the whitelist are dropped and logged; values are passed
+    through unchanged. Raises ValueError if params is not a dictionary.
+    """
+    if not isinstance(params, dict):
+        raise ValueError("Parameters must be a dictionary")
+
+    # Whitelist of allowed parameter keys
+    allowed_keys = {
+        "body",
+        "__event_emitter__",
+        "__event_call__",
+        "__chat_id__",
+        "__session_id__",
+        "__message_id__",
+        "__task__",
+        "__task_body__",
+        "__files__",
+        "__user__",
+        "__metadata__",
+        "__oauth_token__",
+        "__request__",
+        "__tools__",
+        "__model__",
+        "__messages__",
+    }
+
+    # Build a new dict so the caller's params are never modified
+    safe_params = {}
+    for key, value in params.items():
+        if key in allowed_keys:
+            safe_params[key] = value
+        else:
+            log.warning("Filtered out unexpected parameter key: %s", key)
+
+    return safe_params
+
+
 def get_function_module_by_id(request: Request, pipe_id: str):
+    # Reject malformed IDs before they reach the lookup below
+    if not validate_pipe_id(pipe_id):
+        # %r escapes newlines/control characters in the rejected value
+        log.error("Invalid pipe_id format: %r", pipe_id)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Invalid function ID format",
+        )
+
+    # Verify that the function exists in the database; the lookup uses the
+    # part of pipe_id before the first "."
+    function = Functions.get_function_by_id(pipe_id.split(".")[0])
+    if not function:
+        log.error("Function not found: %s", pipe_id)
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Function not found",
+        )
+
+    # Verify the function is active
+    if not function.is_active:
+        log.error("Function is not active: %s", pipe_id)
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="Function is not active",
+        )
+
     function_module, _, _ = get_function_module_from_cache(request, pipe_id)
 
     if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
@@ -160,10 +252,45 @@ async def generate_function_chat_completion(
     request, form_data, user, models: dict = {}
 ):
     async def execute_pipe(pipe, params):
-        if inspect.iscoroutinefunction(pipe):
-            return await pipe(**params)
-        else:
-            return pipe(**params)
+        """
+        Run a pipe with whitelisted parameters and a 300 second timeout.
+
+        Coroutine functions are awaited directly; any other callable runs in
+        a worker thread so it does not block the event loop.
+        """
+        # Validate that pipe is actually a callable function
+        if not callable(pipe):
+            raise ValueError("Invalid pipe: must be a callable function")
+
+        # Validate and sanitize parameters
+        try:
+            safe_params = validate_pipe_params(params)
+        except ValueError as e:
+            log.error("Parameter validation failed: %s", e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Invalid parameters provided to pipe function",
+            ) from e
+
+        try:
+            if inspect.iscoroutinefunction(pipe):
+                pending = pipe(**safe_params)
+            else:
+                # to_thread keeps the event loop free and copies contextvars
+                # into the worker thread
+                pending = asyncio.to_thread(pipe, **safe_params)
+            # NOTE: for a sync pipe the timeout only stops the wait; the
+            # worker thread cannot be cancelled and runs to completion.
+            return await asyncio.wait_for(pending, timeout=300.0)
+        except asyncio.TimeoutError:
+            log.error("Pipe execution timed out")
+            raise HTTPException(
+                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+                detail="Pipe execution timed out",
+            )
+        except Exception as e:
+            log.error("Pipe execution failed: %s", e)
+            raise
 
     async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
         if isinstance(res, str):