Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
team_alias: Optional[str] = None
team_tpm_limit: Optional[int] = None
team_rpm_limit: Optional[int] = None
team_max_parallel_requests: Optional[int] = None
team_max_budget: Optional[float] = None
team_models: List = []
team_blocked: bool = False
Expand Down Expand Up @@ -2076,6 +2077,7 @@ class UserAPIKeyAuth(
tpm_limit_per_model: Optional[Dict[str, int]] = None
user_tpm_limit: Optional[int] = None
user_rpm_limit: Optional[int] = None
user_max_parallel_requests: Optional[int] = None
user_email: Optional[str] = None
request_route: Optional[str] = None

Expand Down
38 changes: 38 additions & 0 deletions litellm/proxy/auth/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,44 @@ def get_key_model_tpm_limit(
return None


def get_key_model_max_parallel_requests(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Return the per-model ``max_parallel_requests`` mapping for an API key.

    Sources are consulted in priority order:

    1. Key metadata (``model_max_parallel_requests``)
    2. Key ``model_max_budget`` (``max_parallel_requests`` entry per model)
    3. Team metadata (``model_max_parallel_requests``)

    Note: Uses the ``(dict or {}).get()`` pattern to avoid short-circuit
    issues where an empty dict would skip fallback checks in if/elif chains.

    Returns None when no source defines any per-model limit.
    """
    # 1. Per-model limits set directly on the key's metadata win outright.
    from_key_metadata = (user_api_key_dict.metadata or {}).get(
        "model_max_parallel_requests"
    )
    if from_key_metadata is not None:
        return from_key_metadata

    # 2. Otherwise collect any max_parallel_requests entries embedded in the
    #    per-model budget config (non-dict entries are ignored).
    budget_config = user_api_key_dict.model_max_budget
    if budget_config:
        per_model_limits: Dict[str, int] = {
            model: cfg["max_parallel_requests"]
            for model, cfg in budget_config.items()
            if isinstance(cfg, dict)
            and cfg.get("max_parallel_requests") is not None
        }
        if per_model_limits:
            return per_model_limits

    # 3. Finally fall back to team-level metadata.
    from_team_metadata = (user_api_key_dict.team_metadata or {}).get(
        "model_max_parallel_requests"
    )
    if from_team_metadata is not None:
        return from_team_metadata

    return None


def get_model_rate_limit_from_metadata(
user_api_key_dict: UserAPIKeyAuth,
metadata_accessor_key: Literal["team_metadata", "organization_metadata"],
Expand Down
19 changes: 16 additions & 3 deletions litellm/proxy/hooks/parallel_request_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
get_key_model_max_parallel_requests,
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
Expand Down Expand Up @@ -298,15 +299,18 @@ async def async_pre_call_hook( # noqa: PLR0915
if (
get_key_model_tpm_limit(user_api_key_dict) is not None
or get_key_model_rpm_limit(user_api_key_dict) is not None
or get_key_model_max_parallel_requests(user_api_key_dict) is not None
):
_model = data.get("model", None)
request_count_api_key = (
f"{api_key}::{_model}::{precise_minute}::request_count"
)
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
_max_parallel_for_key_model = get_key_model_max_parallel_requests(user_api_key_dict)
tpm_limit_for_model = None
rpm_limit_for_model = None
max_parallel_for_model = None

if _model is not None:
if _tpm_limit_for_key_model:
Expand All @@ -315,12 +319,15 @@ async def async_pre_call_hook( # noqa: PLR0915
if _rpm_limit_for_key_model:
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)

if _max_parallel_for_key_model:
max_parallel_for_model = _max_parallel_for_key_model.get(_model)

new_val = await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a model
max_parallel_requests=max_parallel_for_model or sys.maxsize,
current=cache_objects["request_count_api_key_model"],
request_count_api_key=request_count_api_key,
tpm_limit=tpm_limit_for_model or sys.maxsize,
Expand Down Expand Up @@ -351,10 +358,13 @@ async def async_pre_call_hook( # noqa: PLR0915
if user_id is not None:
user_tpm_limit = user_api_key_dict.user_tpm_limit
user_rpm_limit = user_api_key_dict.user_rpm_limit
user_max_parallel = user_api_key_dict.user_max_parallel_requests
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
user_rpm_limit = sys.maxsize
if user_max_parallel is None:
user_max_parallel = sys.maxsize

request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
Expand All @@ -363,7 +373,7 @@ async def async_pre_call_hook( # noqa: PLR0915
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
max_parallel_requests=user_max_parallel,
current=cache_objects["request_count_user_id"],
request_count_api_key=request_count_api_key,
tpm_limit=user_tpm_limit,
Expand All @@ -378,11 +388,14 @@ async def async_pre_call_hook( # noqa: PLR0915
if team_id is not None:
team_tpm_limit = user_api_key_dict.team_tpm_limit
team_rpm_limit = user_api_key_dict.team_rpm_limit
team_max_parallel = user_api_key_dict.team_max_parallel_requests

if team_tpm_limit is None:
team_tpm_limit = sys.maxsize
if team_rpm_limit is None:
team_rpm_limit = sys.maxsize
if team_max_parallel is None:
team_max_parallel = sys.maxsize

request_count_api_key = f"{team_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
Expand All @@ -391,7 +404,7 @@ async def async_pre_call_hook( # noqa: PLR0915
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
max_parallel_requests=team_max_parallel,
current=cache_objects["request_count_team_id"],
request_count_api_key=request_count_api_key,
tpm_limit=team_tpm_limit,
Expand Down
130 changes: 130 additions & 0 deletions tests/test_litellm/proxy/auth/test_max_parallel_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Tests for max_parallel_requests per model/user/team support.

Tests the helper function and rate limiter integration.
"""

import pytest
from typing import Dict, Optional

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import get_key_model_max_parallel_requests


class TestGetKeyModelMaxParallelRequests:
    """Unit tests for the get_key_model_max_parallel_requests helper."""

    def test_returns_none_when_no_limits_set(self):
        """With no parallel-request limits configured anywhere, expect None."""
        auth = UserAPIKeyAuth(api_key="sk-test", metadata={})
        assert get_key_model_max_parallel_requests(auth) is None

    def test_extracts_from_metadata(self):
        """Limits stored in key metadata are returned as-is."""
        auth = UserAPIKeyAuth(
            api_key="sk-test",
            metadata={
                "model_max_parallel_requests": {
                    "gpt-4": 5,
                    "gpt-3.5-turbo": 10,
                }
            },
        )
        extracted = get_key_model_max_parallel_requests(auth)
        assert extracted == {"gpt-4": 5, "gpt-3.5-turbo": 10}

    def test_extracts_from_model_max_budget(self):
        """Falls back to model_max_budget when key metadata has no limits."""
        auth = UserAPIKeyAuth(
            api_key="sk-test",
            metadata={},  # empty dict rather than None
            model_max_budget={
                "gpt-4": {"max_parallel_requests": 3, "budget": 100.0},
                # entry below has no max_parallel_requests → omitted
                "gpt-3.5-turbo": {"budget": 50.0},
            },
        )
        assert get_key_model_max_parallel_requests(auth) == {"gpt-4": 3}

    def test_extracts_from_team_metadata(self):
        """Falls back to team_metadata when no key-level source has limits."""
        auth = UserAPIKeyAuth(
            api_key="sk-test",
            metadata={},  # empty dict rather than None
            team_metadata={"model_max_parallel_requests": {"gpt-4": 8}},
        )
        assert get_key_model_max_parallel_requests(auth) == {"gpt-4": 8}

    def test_metadata_takes_precedence(self):
        """Key metadata wins over team_metadata when both define limits."""
        auth = UserAPIKeyAuth(
            api_key="sk-test",
            metadata={"model_max_parallel_requests": {"gpt-4": 5}},
            team_metadata={"model_max_parallel_requests": {"gpt-4": 10}},
        )
        assert get_key_model_max_parallel_requests(auth) == {"gpt-4": 5}

    def test_skips_none_values_in_model_max_budget(self):
        """Budget entries whose max_parallel_requests is None are dropped."""
        auth = UserAPIKeyAuth(
            api_key="sk-test",
            metadata={},  # empty dict rather than None
            model_max_budget={
                "gpt-4": {"max_parallel_requests": None},
                "gpt-3.5-turbo": {"max_parallel_requests": 5},
            },
        )
        assert get_key_model_max_parallel_requests(auth) == {"gpt-3.5-turbo": 5}


class TestUserMaxParallelRequests:
    """Tests covering the user-level user_max_parallel_requests field."""

    def test_field_exists_on_user_api_key_auth(self):
        """The field is accepted by the constructor and stored unchanged."""
        auth = UserAPIKeyAuth(
            api_key="sk-test",
            user_id="user-123",
            user_max_parallel_requests=10,
        )
        assert auth.user_max_parallel_requests == 10

    def test_field_defaults_to_none(self):
        """Omitting the field leaves it as None (no limit)."""
        auth = UserAPIKeyAuth(api_key="sk-test", user_id="user-123")
        assert auth.user_max_parallel_requests is None


class TestTeamMaxParallelRequests:
    """Tests covering the team-level team_max_parallel_requests field."""

    def test_field_exists_on_user_api_key_auth(self):
        """The field is accepted by the constructor and stored unchanged."""
        auth = UserAPIKeyAuth(
            api_key="sk-test",
            team_id="team-456",
            team_max_parallel_requests=50,
        )
        assert auth.team_max_parallel_requests == 50

    def test_field_defaults_to_none(self):
        """Omitting the field leaves it as None (no limit)."""
        auth = UserAPIKeyAuth(api_key="sk-test", team_id="team-456")
        assert auth.team_max_parallel_requests is None
Loading