Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 75 additions & 23 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
# LangSmith 트레이싱: langchain import 전에 os.environ 주입 필수
# (langchain SDK는 import 시점에 LANGCHAIN_TRACING_V2를 읽으므로 순서가 중요)
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -78,6 +78,7 @@
from src.services.auth import AuthService
from src.services.biz_mapper import BizMapper
from src.services.corp_brand_resolver import resolve_brand_for_industry
from src.services.jwt_auth import UserContext, get_optional_user

from models.explainability.shap_analysis import explain_tcn_prediction
from models.explainability.simulation import (
Expand Down Expand Up @@ -226,28 +227,56 @@ def _pipeline_key(input_data: Any) -> str:
return f"{input_data.target_district}:{input_data.business_type}:{input_data.brand_name}:{rent}:{area}:{radius}:{pop_w}"


def _validate_and_resolve_brand(input_data: SimulationInput) -> None:
def _resolve_user_biz_number(user: UserContext | None) -> str | None:
"""JWT user → users.biz_number 조회. master 는 본인, manager 는 owner 의 biz_number."""
if user is None:
return None
target_id = user.owner_id if user.role == "manager" else user.user_id
if not target_id:
return None
try:
import sqlalchemy as sa

engine = sa.create_engine(settings.postgres_url)
with engine.connect() as conn:
row = conn.execute(
sa.text("SELECT biz_number FROM users WHERE id = :id"),
{"id": target_id},
).first()
return row._mapping["biz_number"] if row else None
except Exception as ex:
logger.warning(f"[brand_resolver] biz_number 조회 실패: {ex}")
return None


def _validate_and_resolve_brand(
input_data: SimulationInput,
current_user: UserContext | None = None,
) -> None:
"""biz_number 입력 시 corp 검증 + 다업종 corp 의 brand auto-resolve.

동작 (input_data.biz_number 가 입력됐을 때만):
biz_number 우선순위:
1. ``input_data.biz_number`` (frontend 명시 입력)
2. JWT ``current_user`` 토큰에서 자동 추출 (master.user_id 또는 manager.owner_id)
3. 없으면 검증 skip (개인사업자 / 비회원 호환)

동작:
1. business_type 이 사용자 corp 의 운영 업종인지 검증.
2. 운영 외 업종 → HTTPException(400) + 운영 가능 업종 list 응답.
3. 운영 내 업종 + corp 의 해당 업종 brand 가 다른 brand 면 brand_name override.

biz_number 미입력 (개인사업자 / 비회원) → 검증 skip, 사용자 brand_name 그대로.
FTC 미등록 corp → 검증 skip + 경고 로그.
"""
if not input_data.biz_number:
biz_number = input_data.biz_number or _resolve_user_biz_number(current_user)
if not biz_number:
return

result = resolve_brand_for_industry(input_data.biz_number, input_data.business_type)
result = resolve_brand_for_industry(biz_number, input_data.business_type)

if result.get("error") == "INDUSTRY_NOT_OPERATED":
raise HTTPException(status_code=400, detail=result)

if result.get("error") in {"USER_NOT_FOUND", "CORP_NOT_IN_FTC", "INVALID_COMPANY_NAME"}:
# 비회원 / FTC 미등록 → 검증 skip, 사용자 brand_name 그대로
logger.warning(f"[brand_resolver] {result['error']} biz={input_data.biz_number} — fallback to input.brand_name")
logger.warning(f"[brand_resolver] {result['error']} biz={biz_number} — fallback to input.brand_name")
return

# 성공: brand_name override (사용자가 다른 brand 입력했어도 corp 정합 brand 로 교체)
Expand Down Expand Up @@ -963,14 +992,18 @@ async def get_status(job_id: str):


@app.post("/analyze")
async def analyze_location(input_data: SimulationInput, response: Response):
async def analyze_location(
input_data: SimulationInput,
response: Response,
current_user: UserContext | None = Depends(get_optional_user),
):
"""[DEPRECATED] 풀파이프 상권 분석 — 전환 기간 동안만 유지.

IM3-259로 endpoint를 분리(/predict + /analyze/llm)했으므로 신규 호출은
그쪽으로 옮길 것. 이 endpoint는 기존 프론트/테스트 호환을 위해 유지하다가
충분히 검증되면 제거 예정.
"""
_validate_and_resolve_brand(input_data)
_validate_and_resolve_brand(input_data, current_user)
from src.config.constants import MAPO_DISTRICTS

# IM3-259: deprecation 헤더 — 클라이언트가 /predict + /analyze/llm 으로 옮길 것을 알림
Expand Down Expand Up @@ -1021,12 +1054,15 @@ async def analyze_location(input_data: SimulationInput, response: Response):


@app.post("/analyze/llm")
async def analyze_llm(input_data: SimulationInput):
async def analyze_llm(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
):
"""AI 분석 전용 endpoint — slow_graph 실행 (~80-140초).

/predict와 독립 병렬 호출 가능. winner는 ranking 단계에서 자체 결정.
"""
_validate_and_resolve_brand(input_data)
_validate_and_resolve_brand(input_data, current_user)
from src.config.constants import MAPO_DISTRICTS
from src.schemas.simulation_output import AnalysisOutput

Expand Down Expand Up @@ -1094,9 +1130,12 @@ async def analyze_llm(input_data: SimulationInput):


@app.post("/analyze/llm/async")
async def analyze_llm_async(input_data: SimulationInput) -> dict[str, Any]:
async def analyze_llm_async(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
) -> dict[str, Any]:
"""AI 분석 비동기 시작 — 즉시 job_id 반환. LangGraph 노드별 진행률 추적."""
_validate_and_resolve_brand(input_data)
_validate_and_resolve_brand(input_data, current_user)
from src.config.constants import MAPO_DISTRICTS
from src.schemas.simulation_output import AnalysisOutput
from src.services.job_progress_store import (
Expand Down Expand Up @@ -1208,7 +1247,10 @@ async def analyze_llm_job_status(job_id: str) -> dict[str, Any]:


@app.post("/analyze/quick")
async def analyze_quick(input_data: SimulationInput):
async def analyze_quick(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
):
"""
LLM 없는 경량 랭킹 엔드포인트 (district_ranking 에이전트만 실행).

Expand All @@ -1217,7 +1259,7 @@ async def analyze_quick(input_data: SimulationInput):

응답: { district_rankings, winner_district, top_3_candidates }
"""
_validate_and_resolve_brand(input_data)
_validate_and_resolve_brand(input_data, current_user)
from src.agents.nodes.district_ranking import district_ranking_node
from src.agents.nodes.market_analyst import db_client

Expand Down Expand Up @@ -1692,15 +1734,18 @@ def _mock_simulation_response(target_district: str, request_id: str) -> dict:


@app.post("/predict")
async def predict_districts(input_data: SimulationInput):
async def predict_districts(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
):
"""
선택 동 1~4개 ML 예측 전용 엔드포인트 (LangGraph 미사용)

- district_ranking, winner 로직 없음
- target_districts 전체에 대해 TCN/BEP/폐업률/폐업위험도/SHAP 병렬 실행
- 응답: 동별 예측 결과 리스트 (프론트 멀티라인 차트용)
"""
_validate_and_resolve_brand(input_data)
_validate_and_resolve_brand(input_data, current_user)
from src.config.constants import MAPO_DISTRICTS

target_districts = getattr(input_data, "target_districts", None) or [input_data.target_district]
Expand Down Expand Up @@ -1762,9 +1807,12 @@ async def predict_districts(input_data: SimulationInput):
# 단계: 동별 _predict_single_district 가 끝날 때마다 progress = done/total.
# ---------------------------------------------------------------------------
@app.post("/predict/async")
async def predict_districts_async(input_data: SimulationInput) -> dict[str, Any]:
async def predict_districts_async(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
) -> dict[str, Any]:
"""ML 예측 비동기 시작 — 즉시 job_id 반환. 진행률은 status endpoint 폴링."""
_validate_and_resolve_brand(input_data)
_validate_and_resolve_brand(input_data, current_user)
from src.config.constants import MAPO_DISTRICTS
from src.services.job_progress_store import (
create_job,
Expand Down Expand Up @@ -1878,9 +1926,13 @@ async def predict_job_status(job_id: str) -> dict[str, Any]:


@app.post("/simulate", deprecated=True)
async def run_simulation(input_data: SimulationInput, response: Response):
async def run_simulation(
input_data: SimulationInput,
response: Response,
current_user: UserContext | None = Depends(get_optional_user),
):
"""기본 시뮬레이션 엔드포인트"""
_validate_and_resolve_brand(input_data)
_validate_and_resolve_brand(input_data, current_user)
response.headers["Deprecation"] = "true"
response.headers["Link"] = '</predict>; rel="successor-version", </analyze/llm>; rel="successor-version"'

Expand Down
Loading