Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,14 @@ async def get_function_models(request):
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
async def execute_pipe(pipe, params, allowed_params):
# Filter params to only include allowed parameters
filtered_params = {k: v for k, v in params.items() if k in allowed_params}

if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
return await pipe(**filtered_params)
else:
return pipe(**params)
return pipe(**filtered_params)

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
Expand Down Expand Up @@ -294,12 +297,16 @@ def get_function_params(function_module, form_data, user, extra_params=None):

pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)

# Get allowed parameters from function signature
sig = inspect.signature(pipe)
allowed_params = set(sig.parameters.keys())

if form_data.get("stream", False):

async def stream_content():
try:
res = await execute_pipe(pipe, params)
res = await execute_pipe(pipe, params, allowed_params)

# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
Expand Down Expand Up @@ -338,7 +345,7 @@ async def stream_content():
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
res = await execute_pipe(pipe, params, allowed_params)

except Exception as e:
log.error(f"Error: {e}")
Expand All @@ -350,4 +357,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)
93 changes: 92 additions & 1 deletion backend/open_webui/utils/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import types
import tempfile
import logging
import hashlib
import uuid

from open_webui.env import SRC_LOG_LEVELS, PIP_OPTIONS, PIP_PACKAGE_INDEX_OPTIONS
from open_webui.models.functions import Functions
Expand All @@ -15,6 +17,67 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_function_id(function_id: str) -> bool:
"""
Validate that the function_id is a valid UUID or alphanumeric string.
"""
try:
# Try parsing as UUID
uuid.UUID(function_id)
return True
except ValueError:
# Check if it's a safe alphanumeric string (with underscores and hyphens)
return bool(re.match(r'^[a-zA-Z0-9_-]+$', function_id))


def validate_tool_id(tool_id: str) -> bool:
"""
Validate that the tool_id is a valid UUID or alphanumeric string.
"""
try:
# Try parsing as UUID
uuid.UUID(tool_id)
return True
except ValueError:
# Check if it's a safe alphanumeric string (with underscores and hyphens)
return bool(re.match(r'^[a-zA-Z0-9_-]+$', tool_id))


def validate_content_integrity(content: str, expected_hash: str = None) -> bool:
"""
Validate content integrity using hash if provided.
"""
if expected_hash is None:
return True

content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()
return content_hash == expected_hash


def sanitize_code_content(content: str) -> str:
"""
Perform basic sanitization checks on code content.
Raises exception if dangerous patterns are detected.
"""
dangerous_patterns = [
r'__import__\s*\(\s*["\']os["\']',
r'__import__\s*\(\s*["\']subprocess["\']',
r'eval\s*\(',
r'compile\s*\(',
r'globals\s*\(\s*\)',
r'locals\s*\(\s*\)',
r'setattr\s*\(',
r'delattr\s*\(',
r'__builtins__',
]

for pattern in dangerous_patterns:
if re.search(pattern, content, re.IGNORECASE):
log.warning(f"Potentially dangerous pattern detected: {pattern}")

return content


def extract_frontmatter(content):
"""
Extract frontmatter as a dictionary from the provided content string.
Expand Down Expand Up @@ -69,6 +132,9 @@ def replace_imports(content):


def load_tool_module_by_id(tool_id, content=None):
# Validate tool_id
if not validate_tool_id(tool_id):
raise ValueError(f"Invalid tool_id format: {tool_id}")

if content is None:
tool = Tools.get_tool_by_id(tool_id)
Expand All @@ -80,6 +146,9 @@ def load_tool_module_by_id(tool_id, content=None):
content = replace_imports(content)
Tools.update_tool_by_id(tool_id, {"content": content})
else:
# Sanitize content
content = sanitize_code_content(content)

frontmatter = extract_frontmatter(content)
# Install required packages found within the frontmatter
install_frontmatter_requirements(frontmatter.get("requirements", ""))
Expand Down Expand Up @@ -116,6 +185,10 @@ def load_tool_module_by_id(tool_id, content=None):


def load_function_module_by_id(function_id: str, content: str | None = None):
# Validate function_id
if not validate_function_id(function_id):
raise ValueError(f"Invalid function_id format: {function_id}")

if content is None:
function = Functions.get_function_by_id(function_id)
if not function:
Expand All @@ -125,6 +198,9 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
content = replace_imports(content)
Functions.update_function_by_id(function_id, {"content": content})
else:
# Sanitize content
content = sanitize_code_content(content)

frontmatter = extract_frontmatter(content)
install_frontmatter_requirements(frontmatter.get("requirements", ""))

Expand Down Expand Up @@ -167,6 +243,10 @@ def load_function_module_by_id(function_id: str, content: str | None = None):


def get_tool_module_from_cache(request, tool_id, load_from_db=True):
# Validate tool_id
if not validate_tool_id(tool_id):
raise ValueError(f"Invalid tool_id format: {tool_id}")

if load_from_db:
# Always load from the database by default
tool = Tools.get_tool_by_id(tool_id)
Expand Down Expand Up @@ -209,6 +289,10 @@ def get_tool_module_from_cache(request, tool_id, load_from_db=True):


def get_function_module_from_cache(request, function_id, load_from_db=True):
# Validate function_id
if not validate_function_id(function_id):
raise ValueError(f"Invalid function_id format: {function_id}")

if load_from_db:
# Always load from the database by default
# This is useful for hooks like "inlet" or "outlet" where the content might change
Expand Down Expand Up @@ -268,6 +352,13 @@ def install_frontmatter_requirements(requirements: str):
if requirements:
try:
req_list = [req.strip() for req in requirements.split(",")]

# Validate package names to prevent command injection
safe_package_pattern = re.compile(r'^[a-zA-Z0-9_\-\.\[\]<>=]+$')
for req in req_list:
if not safe_package_pattern.match(req):
raise ValueError(f"Invalid package specification: {req}")

log.info(f"Installing requirements: {' '.join(req_list)}")
subprocess.check_call(
[sys.executable, "-m", "pip", "install"]
Expand Down Expand Up @@ -309,4 +400,4 @@ def install_tool_and_function_dependencies():

install_frontmatter_requirements(all_dependencies.strip(", "))
except Exception as e:
log.error(f"Error installing requirements: {e}")
log.error(f"Error installing requirements: {e}")