Skip to content
Merged
112 changes: 68 additions & 44 deletions backend/app/auth/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Annotated

import jwt
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel.ext.asyncio.session import AsyncSession

Expand All @@ -15,6 +15,53 @@

# HTTPBearer security scheme for extracting JWT tokens
security = HTTPBearer()
# Name of the cookie that get_current_user_from_request reads as a fallback
# when the Authorization header does not carry a usable Bearer token.
# This is a cookie name, not a credential, hence the S105 suppression.
TOKEN_COOKIE_KEY = "receipt_scanner_token" # noqa: S105


def _unauthorized() -> HTTPException:
    """Build the 401 error used for missing or invalid credentials.

    The exception is returned, not raised, so each call site can choose how
    to chain it (``from err`` or ``from None``).

    Returns:
        An ``HTTPException`` with status 401 and a ``Bearer`` challenge header.
    """
    challenge_headers = {"WWW-Authenticate": "Bearer"}
    error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers=challenge_headers,
    )
    return error


def _inactive() -> HTTPException:
    """Build the 401 error used when the resolved user is not active.

    Returns:
        An ``HTTPException`` with status 401 and a ``Bearer`` challenge header.
    """
    challenge_headers = {"WWW-Authenticate": "Bearer"}
    error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="User account is inactive",
        headers=challenge_headers,
    )
    return error


async def _get_user_from_token(token: str, service: AuthService) -> User:
    """Resolve an encoded access token to the active user it identifies.

    Args:
        token: Encoded JWT access token.
        service: Auth service used to load the user by ID.

    Returns:
        The user named by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when decoding raises ``jwt.InvalidTokenError``,
            when ``sub`` is missing or not convertible to ``int``, when the
            user lookup raises ``NotFoundError``, or when the user is
            inactive.
    """
    try:
        claims = decode_access_token(token)
    except jwt.InvalidTokenError as err:
        raise _unauthorized() from err

    subject: str | None = claims.get("sub")
    if subject is None:
        raise _unauthorized()

    # The subject claim carries the user ID as a string; turn it back into an int.
    try:
        user_id = int(subject)
    except (ValueError, TypeError):
        raise _unauthorized() from None

    # Look the user up; an unknown ID is reported like any other bad credential.
    try:
        user = await service.get_user_by_id(user_id)
    except NotFoundError:
        raise _unauthorized() from None

    if user.is_active:
        return user
    raise _inactive()


async def get_auth_service(
Expand All @@ -41,60 +88,37 @@ async def get_current_user(
Raises:
HTTPException: 401 if token is invalid or user not found
"""
token = credentials.credentials
return await _get_user_from_token(credentials.credentials, service)

try:
# Decode the JWT token
payload = decode_access_token(token)
user_id_str: str | None = payload.get("sub")

if user_id_str is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
async def get_current_user_from_request(
request: Request,
service: AuthService = Depends(get_auth_service),
) -> User:
"""Get current user from Authorization header or auth cookie."""
token: str | None = None

# Convert string ID back to int
try:
user_id = int(user_id_str)
except (ValueError, TypeError):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
) from None
auth_header = request.headers.get("Authorization")
if auth_header:
scheme, _, param = auth_header.partition(" ")
if scheme.lower() == "bearer" and param:
token = param
else:
token = None

except jwt.InvalidTokenError as err:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
) from err
if not token:
token = request.cookies.get(TOKEN_COOKIE_KEY)

# Get the user from the database
try:
user = await service.get_user_by_id(user_id)
except NotFoundError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
) from None
if not token:
raise _unauthorized()

if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account is inactive",
headers={"WWW-Authenticate": "Bearer"},
)

return user
return await _get_user_from_token(token, service)


# Type aliases for dependency injection
# AuthDeps: the AuthService produced by get_auth_service.
AuthDeps = Annotated[AuthService, Depends(get_auth_service)]
# CurrentUser: the user resolved by get_current_user.
CurrentUser = Annotated[User, Depends(get_current_user)]
# CurrentUserFromRequest: the user resolved by get_current_user_from_request,
# which accepts the token from the Authorization header or the auth cookie.
CurrentUserFromRequest = Annotated[User, Depends(get_current_user_from_request)]


def require_user_id(user: User) -> int:
Expand Down
6 changes: 6 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Settings(BaseSettings):

# File upload settings
UPLOAD_DIR: Path = Path("uploads")
MAX_UPLOAD_SIZE_MB: int = 10

# CORS Settings
ALLOWED_ORIGINS: list[AnyHttpUrl] = []
Expand Down Expand Up @@ -104,6 +105,11 @@ def setup_directories(self) -> None:
"""Create necessary directories."""
self.UPLOAD_DIR.mkdir(exist_ok=True)

@property
def max_upload_size_bytes(self) -> int:
    """Return the ``MAX_UPLOAD_SIZE_MB`` limit expressed in bytes."""
    bytes_per_megabyte = 1024 * 1024
    return self.MAX_UPLOAD_SIZE_MB * bytes_per_megabyte


# Create global settings instance
settings = Settings()
Expand Down
1 change: 0 additions & 1 deletion backend/app/core/error_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ async def validation_exception_handler(
]
# Log detailed validation errors
logger.info(f"Request validation error: {error_details}")
logger.info(f"Invalid data: {exc.body}")
elif isinstance(exc, PydanticValidationError):
error_details = [
f"Field '{' -> '.join(str(loc) for loc in error['loc'])}' {error['msg']}"
Expand Down
2 changes: 1 addition & 1 deletion backend/app/integrations/pydantic_ai/receipt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def analyze_receipt(
try:
# Convert PIL Image to bytes
img_byte_arr = BytesIO()
image.save(img_byte_arr, format=image.format or "PNG")
image.save(img_byte_arr, format="PNG")
img_bytes = img_byte_arr.getvalue()

# Create dependencies
Expand Down
146 changes: 146 additions & 0 deletions backend/app/integrations/pydantic_ai/receipt_reconcile_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from typing import Any

import httpx
from google.genai.types import ThinkingLevel
from PIL import Image
from pydantic_ai import Agent, RunContext
from pydantic_ai.messages import BinaryContent
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.providers.google import GoogleProvider
from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after
from tenacity import retry_if_exception_type, stop_after_attempt, wait_exponential

from app.core.config import settings
from app.core.exceptions import ServiceUnavailableError
from app.integrations.pydantic_ai.receipt_reconcile_prompt import (
RECEIPT_RECONCILE_SYSTEM_PROMPT,
)
from app.integrations.pydantic_ai.receipt_reconcile_schema import (
ReceiptReconcileAnalysis,
)

# Model configuration - use Gemini 3 Flash by default (faster + cheaper than Pro)
# get_receipt_reconcile_agent() lets the GEMINI_MODEL environment variable
# override this name.
DEFAULT_MODEL = "gemini-3-flash-preview"

# Settings passed as `model_settings` when the reconciliation agent is built.
DEFAULT_MODEL_SETTINGS = GoogleModelSettings(
# Same value as the HTTP client timeout below; presumably seconds — TODO confirm.
timeout=120,
# Request the low thinking level from the model.
google_thinking_config={"thinking_level": ThinkingLevel.LOW},
)


def _create_retrying_http_client() -> httpx.AsyncClient:
    """Build an async HTTP client whose transport retries transient failures.

    Responses with status 429, 502, 503 or 504 are converted into
    ``httpx.HTTPStatusError`` by the response validator, so they are retried
    together with ``httpx.ConnectError``. The retry policy stops after three
    attempts and re-raises the last error.

    Returns:
        An ``httpx.AsyncClient`` using the retrying transport and a timeout
        of 120.
    """
    retryable_statuses = (429, 502, 503, 504)

    def should_retry_status(response: httpx.Response) -> None:
        # Raising here is what hands the response over to the retry policy.
        if response.status_code not in retryable_statuses:
            return
        response.raise_for_status()

    retry_policy = RetryConfig(
        retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.ConnectError)),
        wait=wait_retry_after(
            fallback_strategy=wait_exponential(multiplier=2, max=30),
            max_wait=120,
        ),
        stop=stop_after_attempt(3),
        reraise=True,
    )
    retrying_transport = AsyncTenacityTransport(
        config=retry_policy,
        validate_response=should_retry_status,
    )
    return httpx.AsyncClient(transport=retrying_transport, timeout=120)


@dataclass
class ReceiptReconcileDependencies:
    """Per-run inputs handed to the reconciliation agent as its ``deps``."""

    # Encoded receipt image (analyze_reconciliation passes PNG bytes).
    image_bytes: bytes
    # Receipt total, already formatted as text for the prompt.
    receipt_total: str
    # One dict per line item; the system-prompt builder reads the keys
    # id, name, quantity, unit_price, total_price and currency.
    items: list[dict[str, Any]]


@cache
def get_receipt_reconcile_agent() -> Agent[
    ReceiptReconcileDependencies, ReceiptReconcileAnalysis
]:
    """Build the receipt reconciliation agent on first call and reuse it.

    The model name comes from the ``GEMINI_MODEL`` environment variable and
    falls back to ``DEFAULT_MODEL``. Because of ``@cache`` the environment is
    only consulted on the first call.

    Returns:
        The agent, configured with the static system prompt plus a per-run
        system prompt that lists the receipt total and the items.
    """
    provider = GoogleProvider(
        api_key=settings.GEMINI_API_KEY,
        http_client=_create_retrying_http_client(),
    )
    model = GoogleModel(os.getenv("GEMINI_MODEL", DEFAULT_MODEL), provider=provider)

    # Binary content (the receipt image) is only included outside production.
    is_production = settings.ENVIRONMENT.lower() == "production"
    telemetry = InstrumentationSettings(
        include_content=True,
        include_binary_content=not is_production,
        version=2,
    )

    reconcile_agent: Agent[ReceiptReconcileDependencies, ReceiptReconcileAnalysis] = (
        Agent(
            model=model,
            deps_type=ReceiptReconcileDependencies,
            output_type=ReceiptReconcileAnalysis,
            system_prompt=RECEIPT_RECONCILE_SYSTEM_PROMPT,
            model_settings=DEFAULT_MODEL_SETTINGS,
            retries=3,
            instrument=telemetry,
        )
    )

    @reconcile_agent.system_prompt
    def receipt_context(ctx: RunContext[ReceiptReconcileDependencies]) -> str:
        # One "- id:... name:..." line per item supplied for this run.
        item_lines = (
            f"- id:{item['id']} name:{item['name']} "
            f"qty:{item['quantity']} unit_price:{item['unit_price']} "
            f"total:{item['total_price']} currency:{item['currency']}"
            for item in ctx.deps.items
        )
        items_info = "\n".join(item_lines)

        # The prompt text is kept flush-left so the string content is unchanged.
        return f"""
Receipt total: {ctx.deps.receipt_total}
Items:
{items_info}
"""

    return reconcile_agent


def _image_to_png_bytes(image: Image.Image) -> bytes:
    """Encode *image* as PNG, converting to RGB if its mode cannot be written."""
    buffer = BytesIO()
    try:
        image.save(buffer, format="PNG")
    except OSError:
        # Pillow raises OSError for modes the PNG writer does not support
        # (for example CMYK). Retry once on an RGB copy with a fresh buffer,
        # since the first attempt may have written partial data.
        buffer = BytesIO()
        image.convert("RGB").save(buffer, format="PNG")
    return buffer.getvalue()


async def analyze_reconciliation(
    image: Image.Image,
    receipt_total: str,
    items: list[dict[str, Any]],
) -> ReceiptReconcileAnalysis:
    """Reconcile receipt items using Pydantic AI agent with Gemini Vision.

    The image is re-encoded as PNG and sent to the cached reconciliation
    agent together with the receipt total and the current items.

    Args:
        image: Receipt image to show to the model.
        receipt_total: Receipt total, formatted as text for the prompt.
        items: Current line items; each dict must provide the keys read by
            the agent's per-run system prompt (id, name, quantity,
            unit_price, total_price, currency).

    Returns:
        The agent's structured output.

    Raises:
        ServiceUnavailableError: On any failure while encoding the image or
            running the agent; the original exception is chained as the
            cause.
    """
    try:
        img_bytes = _image_to_png_bytes(image)

        deps = ReceiptReconcileDependencies(
            image_bytes=img_bytes,
            receipt_total=receipt_total,
            items=items,
        )

        messages: list[str | BinaryContent] = [
            "Reconcile by marking duplicate/noise items for removal only.",
            BinaryContent(data=img_bytes, media_type="image/png"),
        ]

        agent = get_receipt_reconcile_agent()
        result = await agent.run(messages, deps=deps)
        return result.output

    except Exception as e:
        # Boundary handler: every failure is reported to callers as one error type.
        raise ServiceUnavailableError(f"Error reconciling receipt: {e}") from e
22 changes: 22 additions & 0 deletions backend/app/integrations/pydantic_ai/receipt_reconcile_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Receipt reconciliation system prompt."""

# Static system prompt for the receipt reconciliation agent. The text is sent
# to the model, so any edit to the string changes runtime behaviour.
# NOTE(review): rules 9-10 talk about "claiming" a match and "explaining
# uncertainty", but ReceiptReconcileAnalysis (receipt_reconcile_schema.py)
# only has an `adjustments` list with a per-item `reason`; there is no field
# for a free-text explanation. Confirm the intended output shape.
# NOTE(review): rule 4 says "item_id" while the per-run context lists items
# as "id:<n>"; confirm the model maps one onto the other reliably.
RECEIPT_RECONCILE_SYSTEM_PROMPT = """You are a receipt reconciliation assistant.

Your job is to reconcile a scanned receipt's line items with the receipt total.

Rules:
1. Treat the receipt header total as authoritative.
2. You may ONLY suggest setting remove=true on existing items that are duplicated/noisy OCR lines.
3. Do NOT invent new items.
4. Use the provided item_id when suggesting changes.
5. Prefer removing obvious duplicated lines over any other strategy.
6. If you detect repeated item sequences or repeated blocks, mark the extra block items with remove=true.
7. If you are unsure or no safe changes are needed, return an empty adjustments list.
8. Provide a short reason for each removal (one short sentence, no calculations).
9. Never claim items already match unless the PROVIDED item list (after your suggested
adjustments) sums to the receipt total within 0.05 tolerance.
10. If current items do not match and you cannot confidently identify duplicates,
return empty adjustments and explain uncertainty briefly.

The receipt image is provided to help you verify the correct line items and prices.
"""
28 changes: 28 additions & 0 deletions backend/app/integrations/pydantic_ai/receipt_reconcile_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Literal

from pydantic import BaseModel, Field


# NOTE(review): pydantic puts the class docstring and the Field descriptions
# into the model's JSON schema; presumably that schema is what the agent shows
# the LLM, so treat those strings as prompt text — confirm before rewording.
class ReceiptItemAdjustment(BaseModel):
"""Suggested adjustment for a receipt item."""

# Identifies which existing item the suggestion applies to.
item_id: int = Field(description="ID of the existing item to adjust")
# Literal[True] with default True: removal is the only adjustment this
# schema can express, and the field can never be False.
remove: Literal[True] = Field(
default=True,
description=(
"Remove this item from the receipt when it looks like a duplicate or OCR noise"
),
)
# Optional justification, limited to 180 characters by validation.
reason: str | None = Field(
default=None,
description="Brief reason for removing this item (1 short sentence)",
max_length=180,
)


# Structured output type of the reconciliation agent (used as its output_type).
class ReceiptReconcileAnalysis(BaseModel):
"""AI reconciliation suggestions for a receipt."""

# Required field with no default: an empty list means "no changes suggested".
adjustments: list[ReceiptItemAdjustment] = Field(
description="List of existing items to remove"
)
4 changes: 0 additions & 4 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy.exc import SQLAlchemyError

from app import __author__
Expand Down Expand Up @@ -79,9 +78,6 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
app.include_router(category_router)
app.include_router(analytics_router)

# Serve uploaded files (receipt images)
app.mount("/uploads", StaticFiles(directory=settings.UPLOAD_DIR), name="uploads")


# Define the root endpoint
@app.get("/", include_in_schema=False)
Expand Down
Loading