diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 00b2ef55aea..8272ce9a59d 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -223,33 +223,97 @@ def parse_section(section): WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") +#################################### +# PATH VALIDATION HELPER +#################################### + +def validate_data_dir(path: Path, allowed_base: Path) -> Path: + """ + Validate that a data directory path is within the allowed base directory. + + Args: + path: The path to validate + allowed_base: The base directory that the path must be within + + Returns: + The validated, resolved path + + Raises: + ValueError: If the path is outside the allowed base directory + """ + try: + resolved_path = path.resolve() + resolved_base = allowed_base.resolve() + + # Ensure the resolved path is within the allowed base + resolved_path.relative_to(resolved_base) + + return resolved_path + except (ValueError, RuntimeError) as e: + log.error(f"Invalid DATA_DIR path: {path}. Must be within {allowed_base}") + raise ValueError(f"Invalid DATA_DIR path: {path}. Path traversal detected.") from e + #################################### # DATA/FRONTEND BUILD DIR #################################### -DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() +# Get the raw DATA_DIR from environment or use default +raw_data_dir = os.getenv("DATA_DIR", str(BACKEND_DIR / "data")) +DATA_DIR = validate_data_dir(Path(raw_data_dir), BACKEND_DIR) if FROM_INIT_PY: - NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve() + raw_new_data_dir = os.getenv("DATA_DIR", str(OPEN_WEBUI_DIR / "data")) + NEW_DATA_DIR = validate_data_dir(Path(raw_new_data_dir), OPEN_WEBUI_DIR) NEW_DATA_DIR.mkdir(parents=True, exist_ok=True) # Check if the data directory exists in the package directory if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR: - log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}") - for item in DATA_DIR.iterdir(): - dest = NEW_DATA_DIR / item.name - if item.is_dir(): - shutil.copytree(item, dest, dirs_exist_ok=True) - else: - shutil.copy2(item, dest) - - # Zip the data directory - shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR) - - # Remove the old data directory - shutil.rmtree(DATA_DIR) - - DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")) + # Validate both paths are safe before migration + try: + # Ensure DATA_DIR is within BACKEND_DIR + DATA_DIR.relative_to(BACKEND_DIR.resolve()) + # Ensure NEW_DATA_DIR is within OPEN_WEBUI_DIR + NEW_DATA_DIR.relative_to(OPEN_WEBUI_DIR.resolve()) + + log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}") + for item in DATA_DIR.iterdir(): + # Validate each item to prevent symlink attacks + item_resolved = item.resolve() + # Ensure the resolved item is still within DATA_DIR + try: + item_resolved.relative_to(DATA_DIR.resolve()) + except ValueError: + log.warning(f"Skipping {item}: resolves outside DATA_DIR") + continue + + dest = NEW_DATA_DIR / item.name + # Validate destination + dest_resolved = dest.resolve() + try: + dest_resolved.relative_to(NEW_DATA_DIR.resolve()) + except ValueError: + log.warning(f"Skipping {item}: destination outside NEW_DATA_DIR") + continue + + if item.is_dir(): + shutil.copytree(item_resolved, dest_resolved, dirs_exist_ok=True, symlinks=False) + else: + shutil.copy2(item_resolved, dest_resolved, follow_symlinks=False) + + # Zip the data directory with safe archive path + archive_base = DATA_DIR.parent / "open_webui_data" + # Ensure archive path is safe + archive_base_resolved = archive_base.resolve() + archive_base_resolved.relative_to(BACKEND_DIR.resolve()) + shutil.make_archive(str(archive_base_resolved), "zip", DATA_DIR) + + # Remove the old data directory + shutil.rmtree(DATA_DIR) + except ValueError as e: + log.error(f"Data migration failed due to path validation: {e}") + raise + + DATA_DIR = validate_data_dir(Path(os.getenv("DATA_DIR", str(OPEN_WEBUI_DIR / "data"))), OPEN_WEBUI_DIR) STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")) @@ -732,104 +796,4 @@ def parse_section(section): # Comma separated list of logger names to use for audit logging # Default is "uvicorn.access" which is the access log for Uvicorn -# You can add more logger names to this list if you want to capture more logs -AUDIT_UVICORN_LOGGER_NAMES = os.getenv( - "AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access" -).split(",") - -# METADATA | REQUEST | REQUEST_RESPONSE -AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper() -try: - MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048) -except ValueError: - MAX_BODY_LOG_SIZE = 2048 - -# Comma separated list for urls to exclude from audit -AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split( - "," -) -AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS] -AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS] - - -#################################### -# OPENTELEMETRY -#################################### - -ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true" -ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true" -ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true" -ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true" - -OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317" -) -OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT -) -OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get( - "OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT -) -OTEL_EXPORTER_OTLP_INSECURE = ( - os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true" -) -OTEL_METRICS_EXPORTER_OTLP_INSECURE = ( - os.environ.get( - "OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE) - ).lower() - == "true" -) -OTEL_LOGS_EXPORTER_OTLP_INSECURE = ( - os.environ.get( - "OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE) - ).lower() - == "true" -) -OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui") -OTEL_RESOURCE_ATTRIBUTES = os.environ.get( - "OTEL_RESOURCE_ATTRIBUTES", "" -) # e.g. key1=val1,key2=val2 -OTEL_TRACES_SAMPLER = os.environ.get( - "OTEL_TRACES_SAMPLER", "parentbased_always_on" -).lower() -OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "") -OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "") - -OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get( - "OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME -) -OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get( - "OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD -) -OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get( - "OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME -) -OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get( - "OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD -) - -OTEL_OTLP_SPAN_EXPORTER = os.environ.get( - "OTEL_OTLP_SPAN_EXPORTER", "grpc" -).lower() # grpc or http - -OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get( - "OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER -).lower() # grpc or http - -OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get( - "OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER -).lower() # grpc or http - -#################################### -# TOOLS/FUNCTIONS PIP OPTIONS -#################################### - -PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split() -PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split() - - -#################################### -# PROGRESSIVE WEB APP OPTIONS -#################################### - -EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL") +# You can add more logger names \ No newline at end of file diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 316efe18e7f..8a855fe71a4 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -3,6 +3,8 @@ import inspect import json import asyncio +import hashlib +import hmac from pydantic import BaseModel from typing import AsyncGenerator, Generator, Iterator @@ -57,6 +59,93 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) +def validate_valves(valves_dict: dict, valves_class) -> bool: + """ + Validates that valve values are safe and match expected types. + Returns True if valid, False otherwise. + """ + if not valves_dict: + return True + + try: + # Get the expected fields from the Valves class + if hasattr(valves_class, '__fields__'): + expected_fields = valves_class.__fields__ + + # Check that all provided valves are expected + for key in valves_dict.keys(): + if key not in expected_fields: + log.warning(f"Unexpected valve key: {key}") + return False + + # Validate types match expectations + for key, value in valves_dict.items(): + if value is not None and key in expected_fields: + expected_type = expected_fields[key].annotation + # Basic type checking + if hasattr(expected_type, '__origin__'): + # Handle generic types + continue + if not isinstance(value, expected_type): + # Try to convert basic types + if expected_type in (int, float, str, bool): + continue + log.warning(f"Type mismatch for valve {key}: expected {expected_type}, got {type(value)}") + return False + + return True + except Exception as e: + log.error(f"Error validating valves: {e}") + return False + + +def sanitize_valve_value(value): + """ + Sanitizes a single valve value to prevent code injection. + """ + if value is None: + return None + + # If it's a string, check for dangerous patterns + if isinstance(value, str): + # Reject strings that look like code or system commands + dangerous_patterns = [ + '__import__', + 'exec(', + 'eval(', + 'compile(', + 'os.system', + 'subprocess', + '__builtins__', + '__globals__', + '__code__', + 'open(', + 'file(', + ] + + value_lower = value.lower() + for pattern in dangerous_patterns: + if pattern.lower() in value_lower: + log.warning(f"Potentially dangerous pattern detected in valve value: {pattern}") + raise ValueError(f"Invalid valve value: contains prohibited pattern") + + # For other basic types, return as-is + if isinstance(value, (int, float, bool, list, dict)): + # Recursively sanitize lists and dicts + if isinstance(value, list): + return [sanitize_valve_value(v) for v in value] + elif isinstance(value, dict): + return {k: sanitize_valve_value(v) for k, v in value.items()} + return value + + # For complex objects, only allow if they're from safe types + if hasattr(value, '__dict__'): + log.warning(f"Complex object passed as valve value: {type(value)}") + raise ValueError(f"Invalid valve value type: {type(value)}") + + return value + + def get_function_module_by_id(request: Request, pipe_id: str): function_module, _, _ = get_function_module_from_cache(request, pipe_id) @@ -66,9 +155,23 @@ def get_function_module_by_id(request: Request, pipe_id: str): if valves: try: - function_module.valves = Valves( - **{k: v for k, v in valves.items() if v is not None} - ) + # Validate valves before using them + if not validate_valves(valves, Valves): + log.error(f"Invalid valves configuration for function {pipe_id}") + raise ValueError("Invalid valves configuration") + + # Sanitize valve values + sanitized_valves = {} + for k, v in valves.items(): + if v is not None: + try: + sanitized_valves[k] = sanitize_valve_value(v) + except ValueError as e: + log.error(f"Error sanitizing valve {k} for function {pipe_id}: {e}") + raise ValueError(f"Invalid valve value for {k}") + + # Create valves instance with sanitized values + function_module.valves = Valves(**sanitized_valves) except Exception as e: log.exception(f"Error loading valves for function {pipe_id}: {e}") raise e @@ -212,7 +315,18 @@ def get_function_params(function_module, form_data, user, extra_params=None): if "__user__" in params and hasattr(function_module, "UserValves"): user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) try: - params["__user__"]["valves"] = function_module.UserValves(**user_valves) + # Validate and sanitize user valves + if validate_valves(user_valves, function_module.UserValves): + sanitized_user_valves = {} + for k, v in user_valves.items(): + try: + sanitized_user_valves[k] = sanitize_valve_value(v) + except ValueError as e: + log.error(f"Error sanitizing user valve {k}: {e}") + continue + params["__user__"]["valves"] = function_module.UserValves(**sanitized_user_valves) + else: + params["__user__"]["valves"] = function_module.UserValves() except Exception as e: log.exception(e) params["__user__"]["valves"] = function_module.UserValves() @@ -350,4 +464,4 @@ async def stream_content(): return res.model_dump() message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data["model"], message) \ No newline at end of file