Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 175 additions & 12 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
# LangSmith 트레이싱: langchain import 전에 os.environ 주입 필수
# (langchain SDK는 import 시점에 LANGCHAIN_TRACING_V2를 읽으므로 순서가 중요)
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -77,6 +77,7 @@
from src.schemas.simulation_input import SimulationInput
from src.services.auth import AuthService
from src.services.biz_mapper import BizMapper
from src.services.jwt_auth import UserContext, get_optional_user

from models.explainability.shap_analysis import explain_tcn_prediction
from models.explainability.simulation import (
Expand Down Expand Up @@ -225,6 +226,70 @@ def _pipeline_key(input_data: Any) -> str:
return f"{input_data.target_district}:{input_data.business_type}:{input_data.brand_name}:{rent}:{area}:{radius}:{pop_w}"


def _resolve_user_biz_number(user: UserContext | None) -> str | None:
"""JWT user → users.biz_number 조회. master 는 본인, manager 는 owner 의 biz_number."""
if user is None:
return None
target_id = user.owner_id if user.role == "manager" else user.user_id
if not target_id:
return None
try:
import sqlalchemy as sa

engine = sa.create_engine(settings.postgres_url)
with engine.connect() as conn:
row = conn.execute(
sa.text("SELECT biz_number FROM users WHERE id = :id"),
{"id": target_id},
).first()
return row._mapping["biz_number"] if row else None
except Exception as ex:
logger.warning(f"[brand_resolver] biz_number 조회 실패: {ex}")
return None


def _validate_and_resolve_brand(
input_data: SimulationInput,
current_user: UserContext | None = None,
) -> None:
"""biz_number 입력 시 corp 검증 + 다업종 corp 의 brand auto-resolve.

biz_number 우선순위:
1. ``input_data.biz_number`` (frontend 명시 입력)
2. JWT ``current_user`` 토큰에서 자동 추출 (master.user_id 또는 manager.owner_id)
3. 없으면 검증 skip (개인사업자 / 비회원 호환)

동작:
1. business_type 이 사용자 corp 의 운영 업종인지 검증.
2. 운영 외 업종 → HTTPException(400) + 운영 가능 업종 list 응답.
3. 운영 내 업종 + corp 의 해당 업종 brand 가 다른 brand 면 brand_name override.
"""
biz_number = input_data.biz_number or _resolve_user_biz_number(current_user)
if not biz_number:
return

from src.services.corp_brand_resolver import resolve_brand_for_industry

result = resolve_brand_for_industry(biz_number, input_data.business_type)

if result.get("error") == "INDUSTRY_NOT_OPERATED":
raise HTTPException(status_code=400, detail=result)

if result.get("error") in {"USER_NOT_FOUND", "CORP_NOT_IN_FTC", "INVALID_COMPANY_NAME"}:
# 비회원 / FTC 미등록 → 검증 skip, 사용자 brand_name 그대로
logger.warning(f"[brand_resolver] {result['error']} biz={biz_number} — fallback to input.brand_name")
return

# 성공: brand_name override (사용자가 다른 brand 입력했어도 corp 정합 brand 로 교체)
resolved_brand = result["brand_name"]
if input_data.brand_name != resolved_brand:
logger.info(
f"[brand_resolver] auto-resolve: input.brand_name='{input_data.brand_name}' → '{resolved_brand}' "
f"(corp={result['company_name']}, industry={input_data.business_type})"
)
input_data.brand_name = resolved_brand


_BIZ_TYPE_NORMALIZE: dict[str, str] = {
"cafe": "카페",
"coffee": "카페",
Expand Down Expand Up @@ -930,7 +995,11 @@ async def get_status(job_id: str):


@app.post("/analyze")
async def analyze_location(input_data: SimulationInput, response: Response):
async def analyze_location(
input_data: SimulationInput,
response: Response,
current_user: UserContext | None = Depends(get_optional_user),
):
"""[DEPRECATED] 풀파이프 상권 분석 — 전환 기간 동안만 유지.

IM3-259로 endpoint를 분리(/predict + /analyze/llm)했으므로 신규 호출은
Expand All @@ -943,6 +1012,9 @@ async def analyze_location(input_data: SimulationInput, response: Response):
response.headers["Deprecation"] = "true"
response.headers["Link"] = '</predict>; rel="successor-version", </analyze/llm>; rel="successor-version"'

# corp 다업종 brand auto-resolve + 운영 외 업종 차단 (biz_number 입력 시만)
_validate_and_resolve_brand(input_data, current_user)

if input_data.target_district not in MAPO_DISTRICTS:
return {
"status": "error",
Expand Down Expand Up @@ -970,7 +1042,9 @@ async def analyze_location(input_data: SimulationInput, response: Response):
result["all_competitor_locations"] = await _collect_all_competitor_locations(
winner, top3, input_data.business_type
)
result["same_brand_locations"] = await _collect_same_brand_locations(winner, top3, input_data.brand_name, input_data.business_type)
result["same_brand_locations"] = await _collect_same_brand_locations(
winner, top3, input_data.brand_name, input_data.business_type
)
return {"status": "success", "data": result}
except Exception as e:
print(f"!!! [API ERROR] !!! {str(e)}")
Expand All @@ -987,14 +1061,20 @@ async def analyze_location(input_data: SimulationInput, response: Response):


@app.post("/analyze/llm")
async def analyze_llm(input_data: SimulationInput):
async def analyze_llm(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
):
"""AI 분석 전용 endpoint — slow_graph 실행 (~80-140초).

/predict와 독립 병렬 호출 가능. winner는 ranking 단계에서 자체 결정.
"""
from src.config.constants import MAPO_DISTRICTS
from src.schemas.simulation_output import AnalysisOutput

# corp 다업종 brand auto-resolve + 운영 외 업종 차단
_validate_and_resolve_brand(input_data, current_user)

if input_data.target_district not in MAPO_DISTRICTS:
return {
"status": "error",
Expand Down Expand Up @@ -1035,7 +1115,9 @@ async def analyze_llm(input_data: SimulationInput):
print(f"[ANALYZE/LLM] all_competitor_locations 수집 실패 (무시): {e}")
full["all_competitor_locations"] = []
try:
full["same_brand_locations"] = await _collect_same_brand_locations(winner, top3, input_data.brand_name, input_data.business_type)
full["same_brand_locations"] = await _collect_same_brand_locations(
winner, top3, input_data.brand_name, input_data.business_type
)
except Exception as e:
print(f"[ANALYZE/LLM] same_brand_locations 수집 실패 (무시): {e}")
full["same_brand_locations"] = []
Expand All @@ -1059,7 +1141,10 @@ async def analyze_llm(input_data: SimulationInput):


@app.post("/analyze/llm/async")
async def analyze_llm_async(input_data: SimulationInput) -> dict[str, Any]:
async def analyze_llm_async(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
) -> dict[str, Any]:
"""AI 분석 비동기 시작 — 즉시 job_id 반환. LangGraph 노드별 진행률 추적."""
from src.config.constants import MAPO_DISTRICTS
from src.schemas.simulation_output import AnalysisOutput
Expand All @@ -1070,6 +1155,9 @@ async def analyze_llm_async(input_data: SimulationInput) -> dict[str, Any]:
set_progress,
)

# corp 다업종 brand auto-resolve + 운영 외 업종 차단
_validate_and_resolve_brand(input_data, current_user)

if input_data.target_district not in MAPO_DISTRICTS:
return {
"status": "error",
Expand Down Expand Up @@ -1128,7 +1216,9 @@ async def _run() -> None:
logger.warning(f"[/analyze/llm/async] all_competitor_locations 실패 (무시): {ce}")
full["all_competitor_locations"] = []
try:
full["same_brand_locations"] = await _collect_same_brand_locations(winner, top3, input_data.brand_name, input_data.business_type)
full["same_brand_locations"] = await _collect_same_brand_locations(
winner, top3, input_data.brand_name, input_data.business_type
)
except Exception as ce:
logger.warning(f"[/analyze/llm/async] same_brand_locations 실패 (무시): {ce}")
full["same_brand_locations"] = []
Expand All @@ -1137,6 +1227,13 @@ async def _run() -> None:
payload = {k: v for k, v in full.items() if k in analysis_keys}
payload["request_id"] = request_id
payload["target_district"] = full.get("target_district") or input_data.target_district
# DEBUG: payload 직전 same_brand_locations 검증 (frontend 측 누락 의심 시)
logger.info(
f"[/analyze/llm/async] payload check job={job_id[:8]} "
f"same_brand={len(payload.get('same_brand_locations', []) or [])} "
f"all_competitor={len(payload.get('all_competitor_locations', []) or [])} "
f"keys_count={len(payload)}"
)
set_done(job_id, _safe_json(payload))
logger.info(f"[/analyze/llm/async] 완료 job={job_id[:8]}")
except Exception as e:
Expand Down Expand Up @@ -1172,7 +1269,10 @@ async def analyze_llm_job_status(job_id: str) -> dict[str, Any]:


@app.post("/analyze/quick")
async def analyze_quick(input_data: SimulationInput):
async def analyze_quick(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
):
"""
LLM 없는 경량 랭킹 엔드포인트 (district_ranking 에이전트만 실행).

Expand All @@ -1184,6 +1284,9 @@ async def analyze_quick(input_data: SimulationInput):
from src.agents.nodes.district_ranking import district_ranking_node
from src.agents.nodes.market_analyst import db_client

# corp 다업종 brand auto-resolve + 운영 외 업종 차단
_validate_and_resolve_brand(input_data, current_user)

normalized_biz = _BIZ_TYPE_NORMALIZE.get(input_data.business_type.lower(), input_data.business_type)

print(f"--- [API] /analyze/quick 요청: {input_data.target_district} / {normalized_biz} ---")
Expand Down Expand Up @@ -1225,6 +1328,45 @@ class BizLookupRequest(BaseModel):
company_name: str = ""


@app.get("/corp/operated-industries")
async def get_operated_industries(
biz_number: str | None = None,
current_user: UserContext | None = Depends(get_optional_user),
) -> dict:
"""사용자 corp 의 운영 업종/브랜드 list 반환.

Frontend 시뮬 입력 폼이 mount 시 호출 — dropdown 에서 운영 외 업종 disable 용.

biz_number 우선순위:
1. query param ``biz_number`` (frontend 명시)
2. JWT 토큰의 user.user_id → users.biz_number 자동 추출

Returns:
성공: ``{"company_name": str, "industries": [str, ...], "brands": [{name, industry, stores}, ...]}``
실패 (USER_NOT_FOUND/CORP_NOT_IN_FTC): ``{"industries": null, "error": ..., "company_name": ...}``
비회원 (biz_number 미입력 + 토큰 없음): ``{"industries": null}`` — 모든 업종 허용
"""
from src.services.corp_brand_resolver import get_corp_industries

biz = biz_number or _resolve_user_biz_number(current_user)
if not biz:
return {"industries": None, "company_name": None, "brands": []}

portfolio = get_corp_industries(biz)
if "error" in portfolio:
return {
"industries": None,
"error": portfolio["error"],
"company_name": portfolio.get("company_name"),
"message": portfolio.get("message"),
}
return {
"company_name": portfolio["company_name"],
"industries": portfolio["industries"],
"brands": portfolio["brands"],
}


@app.post("/biz/lookup")
async def biz_lookup(req: BizLookupRequest):
"""사업자등록번호 + 기업명으로 프랜차이즈 브랜드 매핑.
Expand Down Expand Up @@ -1659,7 +1801,10 @@ def _mock_simulation_response(target_district: str, request_id: str) -> dict:


@app.post("/predict")
async def predict_districts(input_data: SimulationInput):
async def predict_districts(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
):
"""
선택 동 1~4개 ML 예측 전용 엔드포인트 (LangGraph 미사용)

Expand All @@ -1669,6 +1814,9 @@ async def predict_districts(input_data: SimulationInput):
"""
from src.config.constants import MAPO_DISTRICTS

# corp 다업종 brand auto-resolve + 운영 외 업종 차단
_validate_and_resolve_brand(input_data, current_user)

target_districts = getattr(input_data, "target_districts", None) or [input_data.target_district]
target_districts = [d for d in target_districts if d in MAPO_DISTRICTS][:4]

Expand Down Expand Up @@ -1728,7 +1876,10 @@ async def predict_districts(input_data: SimulationInput):
# 단계: 동별 _predict_single_district 가 끝날 때마다 progress = done/total.
# ---------------------------------------------------------------------------
@app.post("/predict/async")
async def predict_districts_async(input_data: SimulationInput) -> dict[str, Any]:
async def predict_districts_async(
input_data: SimulationInput,
current_user: UserContext | None = Depends(get_optional_user),
) -> dict[str, Any]:
"""ML 예측 비동기 시작 — 즉시 job_id 반환. 진행률은 status endpoint 폴링."""
from src.config.constants import MAPO_DISTRICTS
from src.services.job_progress_store import (
Expand All @@ -1738,6 +1889,9 @@ async def predict_districts_async(input_data: SimulationInput) -> dict[str, Any]
set_progress,
)

# corp 다업종 brand auto-resolve + 운영 외 업종 차단
_validate_and_resolve_brand(input_data, current_user)

target_districts = getattr(input_data, "target_districts", None) or [input_data.target_district]
target_districts = [d for d in target_districts if d in MAPO_DISTRICTS][:4]
if not target_districts:
Expand Down Expand Up @@ -1843,13 +1997,20 @@ async def predict_job_status(job_id: str) -> dict[str, Any]:


@app.post("/simulate", deprecated=True)
async def run_simulation(input_data: SimulationInput, response: Response):
async def run_simulation(
input_data: SimulationInput,
response: Response,
current_user: UserContext | None = Depends(get_optional_user),
):
"""기본 시뮬레이션 엔드포인트"""
response.headers["Deprecation"] = "true"
response.headers["Link"] = '</predict>; rel="successor-version", </analyze/llm>; rel="successor-version"'

from src.config.constants import MAPO_DISTRICTS

# corp 다업종 brand auto-resolve + 운영 외 업종 차단
_validate_and_resolve_brand(input_data, current_user)

if input_data.target_district not in MAPO_DISTRICTS:
return {
"status": "error",
Expand All @@ -1872,7 +2033,9 @@ async def run_simulation(input_data: SimulationInput, response: Response):
winner = result.get("winner_district") or input_data.target_district
top3 = result.get("top_3_candidates") or []
try:
result["same_brand_locations"] = await _collect_same_brand_locations(winner, top3, input_data.brand_name, input_data.business_type)
result["same_brand_locations"] = await _collect_same_brand_locations(
winner, top3, input_data.brand_name, input_data.business_type
)
except Exception as ce:
logger.warning(f"[/simulate] same_brand_locations 실패 (무시): {ce}")
result["same_brand_locations"] = []
Expand Down
7 changes: 7 additions & 0 deletions backend/src/schemas/simulation_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ class SimulationInput(BaseModel):
target_day_type: str | None = Field(default=None, description="타겟 요일: 'weekday' | 'weekend' | None(전체)")
target_monthly_sales: int | None = Field(default=None, description="예상 월매출 (원). None=비율만 계산, 금액 제외")

# [corp_brand_resolver] biz_number 검증 트리거.
# frontend 가 보내거나 main.py 에서 JWT 토큰의 user.user_id → users.biz_number 자동 추출.
# corp 검증: 해당 biz_number 가 운영하는 brand+업종 list 매핑.
biz_number: str | None = Field(
default=None, description="사업자등록번호 (corp 다업종 검증 트리거 — 미입력 시 검증 skip)"
)

@field_validator("business_type")
@classmethod
def _warn_unknown_business_type(cls, v: str) -> str:
Expand Down
Loading
Loading