From 29f720b19a354a156f7fa88d6eb580da3b1db2a8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 5 Feb 2026 16:04:04 +0000 Subject: [PATCH] fix: Security vulnerability fixes Automated fixes by UnitOneFlow Security Guard. Vulnerabilities addressed: 5 See security-report.json for details. --- backend/open_webui/functions.py | 164 ++++++++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 9 deletions(-) diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 316efe18e7f..1b1ff0d9901 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -3,7 +3,7 @@ import inspect import json import asyncio - +import hashlib from pydantic import BaseModel from typing import AsyncGenerator, Generator, Iterator from fastapi import ( @@ -57,7 +57,116 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) +def validate_function_access(request: Request, pipe_id: str, user: UserModel = None): + """ + Validate that the user has access to the function and that the function exists. + """ + function = Functions.get_function_by_id(pipe_id) + + if not function: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND + ) + + if not function.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) + + # Verify function integrity + if hasattr(function, 'content') and hasattr(function, 'meta'): + expected_hash = function.meta.get('content_hash') + if expected_hash: + actual_hash = hashlib.sha256(function.content.encode()).hexdigest() + if actual_hash != expected_hash: + log.error(f"Function {pipe_id} failed integrity check") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Function integrity check failed" + ) + + return function + + +def sanitize_value(value, max_depth=5, current_depth=0): + """ + Sanitize values to prevent code injection. + Only allows primitive types and basic structures. + """ + if current_depth > max_depth: + log.warning(f"Maximum nesting depth exceeded during sanitization") + return None + + if value is None: + return None + + if isinstance(value, (str, int, float, bool)): + return value + + if isinstance(value, list): + return [sanitize_value(v, max_depth, current_depth + 1) for v in value] + + if isinstance(value, dict): + return { + str(k): sanitize_value(v, max_depth, current_depth + 1) + for k, v in value.items() + if isinstance(k, (str, int, float)) + } + + log.warning(f"Skipping unsupported type during sanitization: {type(value)}") + return None + + +def sanitize_params(params): + """ + Sanitize all parameters to prevent code injection attacks. + """ + sanitized = {} + + for key, value in params.items(): + # Skip special internal parameters + if key.startswith('__') and key.endswith('__'): + sanitized[key] = value + continue + + # Handle body parameter specially as it's a dict + if key == 'body' and isinstance(value, dict): + sanitized_body = {} + for k, v in value.items(): + sanitized_body[k] = sanitize_value(v) + sanitized[key] = sanitized_body + else: + sanitized[key] = sanitize_value(value) + + return sanitized + + +def validate_valves_schema(valves_class, valves_data): + """ + Validate that user valves data only contains fields defined in the Valves schema. + This prevents arbitrary object injection through extra fields. + """ + if not hasattr(valves_class, '__annotations__'): + return {} + + allowed_fields = set(valves_class.__annotations__.keys()) + validated_data = {} + + for key, value in valves_data.items(): + if key in allowed_fields: + validated_data[key] = value + else: + log.warning(f"Ignoring unexpected valve field: {key}") + + return validated_data + + def get_function_module_by_id(request: Request, pipe_id: str): + # Validate function access and integrity + validate_function_access(request, pipe_id) + function_module, _, _ = get_function_module_from_cache(request, pipe_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): @@ -66,12 +175,23 @@ def get_function_module_by_id(request: Request, pipe_id: str): if valves: try: - function_module.valves = Valves( - **{k: v for k, v in valves.items() if v is not None} - ) + # Sanitize valve values to prevent injection + sanitized_valves = {} + for k, v in valves.items(): + sanitized_v = sanitize_value(v) + if sanitized_v is not None or v is None: + sanitized_valves[k] = sanitized_v + + # Validate schema to prevent object injection + validated_valves = validate_valves_schema(Valves, sanitized_valves) + + function_module.valves = Valves(**validated_valves) except Exception as e: log.exception(f"Error loading valves for function {pipe_id}: {e}") - raise e + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error loading function configuration: {str(e)}" + ) else: function_module.valves = Valves() @@ -160,10 +280,23 @@ async def generate_function_chat_completion( request, form_data, user, models: dict = {} ): async def execute_pipe(pipe, params): + # Sanitize params before execution to prevent code injection + sanitized_params = sanitize_params(params) + + # Validate pipe signature to ensure we're only passing expected parameters + sig = inspect.signature(pipe) + validated_params = {} + + for key, value in sanitized_params.items(): + if key in sig.parameters: + validated_params[key] = value + else: + log.warning(f"Skipping unexpected parameter: {key}") + if inspect.iscoroutinefunction(pipe): - return await pipe(**params) + return await pipe(**validated_params) else: - return pipe(**params) + return pipe(**validated_params) async def get_message_content(res: str | Generator | AsyncGenerator) -> str: if isinstance(res, str): @@ -212,7 +345,20 @@ def get_function_params(function_module, form_data, user, extra_params=None): if "__user__" in params and hasattr(function_module, "UserValves"): user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) + # Sanitize user valve values + sanitized_user_valves = {} + for k, v in user_valves.items(): + sanitized_v = sanitize_value(v) + if sanitized_v is not None or v is None: + sanitized_user_valves[k] = sanitized_v + + # Validate schema to prevent object injection + validated_user_valves = validate_valves_schema( + function_module.UserValves, + sanitized_user_valves + ) + + params["__user__"]["valves"] = function_module.UserValves(**validated_user_valves) except Exception as e: log.exception(e) params["__user__"]["valves"] = function_module.UserValves() @@ -350,4 +496,4 @@ async def stream_content(): return res.model_dump() message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data["model"], message) \ No newline at end of file