diff --git a/backend/src/main.py b/backend/src/main.py index e0a06478..15046217 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -47,7 +47,7 @@ # LangSmith 트레이싱: langchain import 전에 os.environ 주입 필수 # (langchain SDK는 import 시점에 LANGCHAIN_TRACING_V2를 읽으므로 순서가 중요) from dotenv import load_dotenv -from fastapi import FastAPI, HTTPException, Request, Response +from fastapi import Depends, FastAPI, HTTPException, Request, Response from fastapi.concurrency import run_in_threadpool from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse @@ -78,6 +78,7 @@ from src.services.auth import AuthService from src.services.biz_mapper import BizMapper from src.services.corp_brand_resolver import resolve_brand_for_industry +from src.services.jwt_auth import UserContext, get_optional_user from models.explainability.shap_analysis import explain_tcn_prediction from models.explainability.simulation import ( @@ -226,28 +227,56 @@ def _pipeline_key(input_data: Any) -> str: return f"{input_data.target_district}:{input_data.business_type}:{input_data.brand_name}:{rent}:{area}:{radius}:{pop_w}" -def _validate_and_resolve_brand(input_data: SimulationInput) -> None: +def _resolve_user_biz_number(user: UserContext | None) -> str | None: + """JWT user → users.biz_number 조회. master 는 본인, manager 는 owner 의 biz_number.""" + if user is None: + return None + target_id = user.owner_id if user.role == "manager" else user.user_id + if not target_id: + return None + try: + import sqlalchemy as sa + + engine = sa.create_engine(settings.postgres_url) + with engine.connect() as conn: + row = conn.execute( + sa.text("SELECT biz_number FROM users WHERE id = :id"), + {"id": target_id}, + ).first() + return row._mapping["biz_number"] if row else None + except Exception as ex: + logger.warning(f"[brand_resolver] biz_number 조회 실패: {ex}") + return None + + +def _validate_and_resolve_brand( + input_data: SimulationInput, + current_user: UserContext | None = None, +) -> None: """biz_number 입력 시 corp 검증 + 다업종 corp 의 brand auto-resolve. - 동작 (input_data.biz_number 가 입력됐을 때만): + biz_number 우선순위: + 1. ``input_data.biz_number`` (frontend 명시 입력) + 2. JWT ``current_user`` 토큰에서 자동 추출 (master.user_id 또는 manager.owner_id) + 3. 없으면 검증 skip (개인사업자 / 비회원 호환) + + 동작: 1. business_type 이 사용자 corp 의 운영 업종인지 검증. 2. 운영 외 업종 → HTTPException(400) + 운영 가능 업종 list 응답. 3. 운영 내 업종 + corp 의 해당 업종 brand 가 다른 brand 면 brand_name override. - - biz_number 미입력 (개인사업자 / 비회원) → 검증 skip, 사용자 brand_name 그대로. - FTC 미등록 corp → 검증 skip + 경고 로그. """ - if not input_data.biz_number: + biz_number = input_data.biz_number or _resolve_user_biz_number(current_user) + if not biz_number: return - result = resolve_brand_for_industry(input_data.biz_number, input_data.business_type) + result = resolve_brand_for_industry(biz_number, input_data.business_type) if result.get("error") == "INDUSTRY_NOT_OPERATED": raise HTTPException(status_code=400, detail=result) if result.get("error") in {"USER_NOT_FOUND", "CORP_NOT_IN_FTC", "INVALID_COMPANY_NAME"}: # 비회원 / FTC 미등록 → 검증 skip, 사용자 brand_name 그대로 - logger.warning(f"[brand_resolver] {result['error']} biz={input_data.biz_number} — fallback to input.brand_name") + logger.warning(f"[brand_resolver] {result['error']} biz={biz_number} — fallback to input.brand_name") return # 성공: brand_name override (사용자가 다른 brand 입력했어도 corp 정합 brand 로 교체) @@ -963,14 +992,18 @@ async def get_status(job_id: str): @app.post("/analyze") -async def analyze_location(input_data: SimulationInput, response: Response): +async def analyze_location( + input_data: SimulationInput, + response: Response, + current_user: UserContext | None = Depends(get_optional_user), +): """[DEPRECATED] 풀파이프 상권 분석 — 전환 기간 동안만 유지. IM3-259로 endpoint를 분리(/predict + /analyze/llm)했으므로 신규 호출은 그쪽으로 옮길 것. 이 endpoint는 기존 프론트/테스트 호환을 위해 유지하다가 충분히 검증되면 제거 예정. """ - _validate_and_resolve_brand(input_data) + _validate_and_resolve_brand(input_data, current_user) from src.config.constants import MAPO_DISTRICTS # IM3-259: deprecation 헤더 — 클라이언트가 /predict + /analyze/llm 으로 옮길 것을 알림 @@ -1021,12 +1054,15 @@ async def analyze_location(input_data: SimulationInput, response: Response): @app.post("/analyze/llm") -async def analyze_llm(input_data: SimulationInput): +async def analyze_llm( + input_data: SimulationInput, + current_user: UserContext | None = Depends(get_optional_user), +): """AI 분석 전용 endpoint — slow_graph 실행 (~80-140초). /predict와 독립 병렬 호출 가능. winner는 ranking 단계에서 자체 결정. """ - _validate_and_resolve_brand(input_data) + _validate_and_resolve_brand(input_data, current_user) from src.config.constants import MAPO_DISTRICTS from src.schemas.simulation_output import AnalysisOutput @@ -1094,9 +1130,12 @@ async def analyze_llm(input_data: SimulationInput): @app.post("/analyze/llm/async") -async def analyze_llm_async(input_data: SimulationInput) -> dict[str, Any]: +async def analyze_llm_async( + input_data: SimulationInput, + current_user: UserContext | None = Depends(get_optional_user), +) -> dict[str, Any]: """AI 분석 비동기 시작 — 즉시 job_id 반환. LangGraph 노드별 진행률 추적.""" - _validate_and_resolve_brand(input_data) + _validate_and_resolve_brand(input_data, current_user) from src.config.constants import MAPO_DISTRICTS from src.schemas.simulation_output import AnalysisOutput from src.services.job_progress_store import ( @@ -1208,7 +1247,10 @@ async def analyze_llm_job_status(job_id: str) -> dict[str, Any]: @app.post("/analyze/quick") -async def analyze_quick(input_data: SimulationInput): +async def analyze_quick( + input_data: SimulationInput, + current_user: UserContext | None = Depends(get_optional_user), +): """ LLM 없는 경량 랭킹 엔드포인트 (district_ranking 에이전트만 실행). @@ -1217,7 +1259,7 @@ async def analyze_quick(input_data: SimulationInput): 응답: { district_rankings, winner_district, top_3_candidates } """ - _validate_and_resolve_brand(input_data) + _validate_and_resolve_brand(input_data, current_user) from src.agents.nodes.district_ranking import district_ranking_node from src.agents.nodes.market_analyst import db_client @@ -1692,7 +1734,10 @@ def _mock_simulation_response(target_district: str, request_id: str) -> dict: @app.post("/predict") -async def predict_districts(input_data: SimulationInput): +async def predict_districts( + input_data: SimulationInput, + current_user: UserContext | None = Depends(get_optional_user), +): """ 선택 동 1~4개 ML 예측 전용 엔드포인트 (LangGraph 미사용) @@ -1700,7 +1745,7 @@ async def predict_districts(input_data: SimulationInput): - target_districts 전체에 대해 TCN/BEP/폐업률/폐업위험도/SHAP 병렬 실행 - 응답: 동별 예측 결과 리스트 (프론트 멀티라인 차트용) """ - _validate_and_resolve_brand(input_data) + _validate_and_resolve_brand(input_data, current_user) from src.config.constants import MAPO_DISTRICTS target_districts = getattr(input_data, "target_districts", None) or [input_data.target_district] @@ -1762,9 +1807,12 @@ async def predict_districts(input_data: SimulationInput): # 단계: 동별 _predict_single_district 가 끝날 때마다 progress = done/total. # --------------------------------------------------------------------------- @app.post("/predict/async") -async def predict_districts_async(input_data: SimulationInput) -> dict[str, Any]: +async def predict_districts_async( + input_data: SimulationInput, + current_user: UserContext | None = Depends(get_optional_user), +) -> dict[str, Any]: """ML 예측 비동기 시작 — 즉시 job_id 반환. 진행률은 status endpoint 폴링.""" - _validate_and_resolve_brand(input_data) + _validate_and_resolve_brand(input_data, current_user) from src.config.constants import MAPO_DISTRICTS from src.services.job_progress_store import ( create_job, @@ -1878,9 +1926,13 @@ async def predict_job_status(job_id: str) -> dict[str, Any]: @app.post("/simulate", deprecated=True) -async def run_simulation(input_data: SimulationInput, response: Response): +async def run_simulation( + input_data: SimulationInput, + response: Response, + current_user: UserContext | None = Depends(get_optional_user), +): """기본 시뮬레이션 엔드포인트""" - _validate_and_resolve_brand(input_data) + _validate_and_resolve_brand(input_data, current_user) response.headers["Deprecation"] = "true" response.headers["Link"] = '; rel="successor-version", ; rel="successor-version"'