Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from contextlib import asynccontextmanager
from queue import Queue
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import uvicorn
import uvicorn.server
Expand Down Expand Up @@ -296,6 +296,8 @@ def __init__(self, lit_api: LitAPI, server: "LitServer"):

async def _prepare_request(self, request, request_type) -> dict:
"""Common request preparation logic."""
if isinstance(request, dict):
return request
if request_type == Request:
content_type = request.headers.get("Content-Type", "")
if content_type == "application/x-www-form-urlencoded" or content_type.startswith("multipart/form-data"):
Expand Down Expand Up @@ -1103,8 +1105,12 @@ def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type):
# Create handlers
handler = StreamingRequestHandler(lit_api, self) if lit_api.stream else RegularRequestHandler(lit_api, self)

# When no Pydantic model is annotated, use Dict[str, Any] so Swagger renders a request body form.
# FastAPI will parse the JSON body into a dict automatically in this case.
swagger_request_type = dict[str, Any] if request_type is Request else request_type

# Create endpoint function
async def endpoint_handler(request: request_type) -> response_type:
async def endpoint_handler(request: swagger_request_type) -> response_type:
return await handler.handle_request(request, request_type)

# Register endpoint
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,21 @@ def test_pydantic():
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}


class NoAnnotationLitAPI(LitAPI):
def setup(self, device):
pass

def predict(self, request):
return {"output": request["input"] ** 2}


def test_swagger_request_body_without_decode_request_annotation():
"""Regression test for https://github.com/Lightning-AI/LitServe/issues/667.
When decode_request has no type annotation, Swagger should still expose a request body form."""
server = LitServer(NoAnnotationLitAPI(), accelerator="cpu", devices=1, timeout=5)
schema = server.app.openapi()
predict_post = schema["paths"]["/predict"]["post"]
assert "requestBody" in predict_post, "Swagger must expose a requestBody for /predict"
assert "application/json" in predict_post["requestBody"]["content"]