From d40f7109f2be486b5e36d80deb62f6d57e70a2b4 Mon Sep 17 00:00:00 2001 From: Candice0313 Date: Fri, 22 May 2026 13:56:26 -0500 Subject: [PATCH 1/2] migrate db layer from sqlite3 to SQLAlchemy (DEV/PROD switchable) --- web/app/db.py | 238 ++++++++++++++++----------------- web/app/initdb_mysql.sql | 21 +++ web/app/initdb_sqlite.sql | 21 +++ web/data/app.db | Bin 131072 -> 131072 bytes web/requirements-min.txt | 2 + web/tests/test_opm.py | 51 +++---- web/tests/test_opm_validate.py | 54 ++++---- 7 files changed, 213 insertions(+), 174 deletions(-) create mode 100644 web/app/initdb_mysql.sql create mode 100644 web/app/initdb_sqlite.sql diff --git a/web/app/db.py b/web/app/db.py index 5c1b141..d119951 100644 --- a/web/app/db.py +++ b/web/app/db.py @@ -1,180 +1,168 @@ from __future__ import annotations +import json as _json import os -import sqlite3 from pathlib import Path from typing import Optional +from dotenv import load_dotenv +from sqlalchemy import create_engine, text +from sqlalchemy.engine import URL + +load_dotenv() # loads .env; system env vars override .env values BASE_DIR = Path(__file__).resolve().parents[1] _data_override = os.environ.get("WEB_DATA_DIR", "").strip() DATA_DIR = Path(_data_override) if _data_override else BASE_DIR / "data" DB_PATH = DATA_DIR / "app.db" - -def ensure_data_directory_exists() -> None: +INITDB_SQLITE = Path(__file__).parent / "initdb_sqlite.sql" +INITDB_MYSQL = Path(__file__).parent / "initdb_mysql.sql" + +DEPLOY_TYPE = os.environ.get("DEPLOY_TYPE", "DEV").strip().upper() + +if DEPLOY_TYPE == "PROD": + url = URL.create( + drivername="mysql+pymysql", + username=os.environ.get("MYSQL_USER", "internta_user"), + password=os.environ["MYSQL_PASSWORD"], + host=os.environ.get("MYSQL_HOST", "mysql6.sqlpub.com"), + port=int(os.environ.get("MYSQL_PORT", "3311")), + database=os.environ.get("MYSQL_DATABASE", "internta_db"), + ) +else: DATA_DIR.mkdir(parents=True, exist_ok=True) + url = URL.create( + drivername="sqlite+pysqlite", + database=str(DB_PATH), + ) - -def get_connection() -> sqlite3.Connection: - ensure_data_directory_exists() - connection = sqlite3.connect(DB_PATH) - connection.row_factory = sqlite3.Row - return connection +engine = create_engine(url) def init_db() -> None: - ensure_data_directory_exists() - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS notes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - content TEXT NOT NULL, - created_at TEXT DEFAULT (datetime('now')) - ); - """ - ) - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS action_items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - note_id INTEGER, - text TEXT NOT NULL, - done INTEGER DEFAULT 0, - created_at TEXT DEFAULT (datetime('now')), - FOREIGN KEY (note_id) REFERENCES notes(id) - ); - """ - ) - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS opm_diagrams ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - note_id INTEGER, - payload TEXT NOT NULL, - created_at TEXT DEFAULT (datetime('now')) - ); - """ - ) - connection.commit() + sql_file = INITDB_MYSQL if DEPLOY_TYPE == "PROD" else INITDB_SQLITE + sql = Path(sql_file).read_text() + statements = [s.strip() for s in sql.split(";") if s.strip()] + with engine.begin() as conn: + for stmt in statements: + conn.execute(text(stmt)) + + +def _last_insert_id(result) -> int: + return int(result.lastrowid) def insert_note(content: str) -> int: - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute("INSERT INTO notes (content) VALUES (?)", (content,)) - connection.commit() - return int(cursor.lastrowid) - - -def list_notes() -> list[sqlite3.Row]: - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute("SELECT id, content, created_at FROM notes ORDER BY id DESC") - return list(cursor.fetchall()) - - -def get_note(note_id: int) -> Optional[sqlite3.Row]: - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute( - "SELECT id, content, created_at FROM notes WHERE id = ?", - (note_id,), + with engine.begin() as conn: + result = conn.execute( + text("INSERT INTO notes (content) VALUES (:content)"), + {"content": content}, ) - row = cursor.fetchone() - return row + return _last_insert_id(result) + + +def list_notes() -> list: + with engine.connect() as conn: + return list( + conn.execute( + text("SELECT id, content, created_at FROM notes ORDER BY id DESC") + ).mappings().fetchall() + ) + + +def get_note(note_id: int) -> Optional[object]: + with engine.connect() as conn: + return conn.execute( + text("SELECT id, content, created_at FROM notes WHERE id = :id"), + {"id": note_id}, + ).mappings().fetchone() def insert_action_items(items: list[str], note_id: Optional[int] = None) -> list[int]: - with get_connection() as connection: - cursor = connection.cursor() - ids: list[int] = [] + ids: list[int] = [] + with engine.begin() as conn: for item in items: - cursor.execute( - "INSERT INTO action_items (note_id, text) VALUES (?, ?)", - (note_id, item), + result = conn.execute( + text("INSERT INTO action_items (note_id, text) VALUES (:note_id, :text)"), + {"note_id": note_id, "text": item}, ) - ids.append(int(cursor.lastrowid)) - connection.commit() - return ids + ids.append(_last_insert_id(result)) + return ids -def list_action_items(note_id: Optional[int] = None) -> list[sqlite3.Row]: - with get_connection() as connection: - cursor = connection.cursor() +def list_action_items(note_id: Optional[int] = None) -> list: + with engine.connect() as conn: if note_id is None: - cursor.execute( - "SELECT id, note_id, text, done, created_at FROM action_items ORDER BY id DESC" - ) + rows = conn.execute( + text( + "SELECT id, note_id, text, done, created_at" + " FROM action_items ORDER BY id DESC" + ) + ).mappings().fetchall() else: - cursor.execute( - "SELECT id, note_id, text, done, created_at FROM action_items WHERE note_id = ? ORDER BY id DESC", - (note_id,), - ) - return list(cursor.fetchall()) + rows = conn.execute( + text( + "SELECT id, note_id, text, done, created_at" + " FROM action_items WHERE note_id = :note_id ORDER BY id DESC" + ), + {"note_id": note_id}, + ).mappings().fetchall() + return list(rows) def mark_action_item_done(action_item_id: int, done: bool) -> None: - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute( - "UPDATE action_items SET done = ? WHERE id = ?", - (1 if done else 0, action_item_id), + with engine.begin() as conn: + conn.execute( + text("UPDATE action_items SET done = :done WHERE id = :id"), + {"done": 1 if done else 0, "id": action_item_id}, ) - connection.commit() - - -import json as _json def insert_opm_diagram(payload: dict, note_id: Optional[int] = None) -> int: - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute( - "INSERT INTO opm_diagrams (note_id, payload) VALUES (?, ?)", - (note_id, _json.dumps(payload)), + with engine.begin() as conn: + result = conn.execute( + text( + "INSERT INTO opm_diagrams (note_id, payload) VALUES (:note_id, :payload)" + ), + {"note_id": note_id, "payload": _json.dumps(payload)}, ) - connection.commit() - return int(cursor.lastrowid) + return _last_insert_id(result) def list_opm_diagrams(limit: Optional[int] = None) -> list[dict]: q = "SELECT id, note_id, payload, created_at FROM opm_diagrams ORDER BY id DESC" - params: tuple = () + params: dict = {} if limit is not None: - q += " LIMIT ?" - params = (int(limit),) - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute(q, params) - rows = cursor.fetchall() + q += " LIMIT :limit" + params["limit"] = int(limit) + with engine.connect() as conn: + rows = conn.execute(text(q), params).fetchall() return [ { - "id": r["id"], - "note_id": r["note_id"], - "created_at": r["created_at"], - "diagram": _json.loads(r["payload"]), + "id": r.id, + "note_id": r.note_id, + "created_at": r.created_at, + "diagram": _json.loads(r.payload), } for r in rows ] def get_opm_diagram(diagram_id: int) -> Optional[dict]: - with get_connection() as connection: - cursor = connection.cursor() - cursor.execute( - "SELECT id, note_id, payload, created_at FROM opm_diagrams WHERE id = ?", - (diagram_id,), - ) - row = cursor.fetchone() + with engine.connect() as conn: + row = conn.execute( + text( + "SELECT id, note_id, payload, created_at" + " FROM opm_diagrams WHERE id = :id" + ), + {"id": diagram_id}, + ).fetchone() if row is None: return None return { - "id": row["id"], - "note_id": row["note_id"], - "created_at": row["created_at"], - "diagram": _json.loads(row["payload"]), + "id": row.id, + "note_id": row.note_id, + "created_at": row.created_at, + "diagram": _json.loads(row.payload), } - diff --git a/web/app/initdb_mysql.sql b/web/app/initdb_mysql.sql new file mode 100644 index 0000000..b9bf7d4 --- /dev/null +++ b/web/app/initdb_mysql.sql @@ -0,0 +1,21 @@ +CREATE TABLE IF NOT EXISTS notes ( + id INT AUTO_INCREMENT PRIMARY KEY, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS action_items ( + id INT AUTO_INCREMENT PRIMARY KEY, + note_id INT, + text TEXT NOT NULL, + done TINYINT(1) DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (note_id) REFERENCES notes(id) +); + +CREATE TABLE IF NOT EXISTS opm_diagrams ( + id INT AUTO_INCREMENT PRIMARY KEY, + note_id INT, + payload TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); diff --git a/web/app/initdb_sqlite.sql b/web/app/initdb_sqlite.sql new file mode 100644 index 0000000..d872489 --- /dev/null +++ b/web/app/initdb_sqlite.sql @@ -0,0 +1,21 @@ +CREATE TABLE IF NOT EXISTS notes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS action_items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + note_id INTEGER, + text TEXT NOT NULL, + done INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (note_id) REFERENCES notes(id) +); + +CREATE TABLE IF NOT EXISTS opm_diagrams ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + note_id INTEGER, + payload TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); diff --git a/web/data/app.db b/web/data/app.db index dc9aeebec0cb2510b6e61b8e076698bb47a25b24..ea4fb7aec0ff01fd6bcbc0dae6f70166a7889e50 100644 GIT binary patch delta 144 zcmZo@;Am*zm>|vgVWNyPypxu`!2> wrIEQgWAa1CGC@;aBLf9PGbcv6SmL#!=0.23.0 pydantic>=2.0.0 openai>=1.0.0 python-dotenv>=1.0.0 +sqlalchemy>=2.0.0 +pymysql>=1.1.0 diff --git a/web/tests/test_opm.py b/web/tests/test_opm.py index c0f0b1c..86fee25 100644 --- a/web/tests/test_opm.py +++ b/web/tests/test_opm.py @@ -2,7 +2,6 @@ import json import sys -from pathlib import Path from typing import Generator from unittest.mock import MagicMock, patch @@ -236,12 +235,18 @@ def test_version_preserved_in_round_trip(): @pytest.fixture() -def tmp_db(tmp_path: Path) -> Generator[Path, None, None]: - db_file = tmp_path / "test.db" - with patch("web.app.db.DB_PATH", db_file): - from web.app import db as db_module +def tmp_db(): + from sqlalchemy import create_engine as _create_engine + from sqlalchemy.pool import StaticPool + test_engine = _create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + with patch("web.app.db.engine", test_engine): + import web.app.db as db_module db_module.init_db() - yield db_file + yield test_engine # --------------------------------------------------------------------------- @@ -249,14 +254,14 @@ def tmp_db(tmp_path: Path) -> Generator[Path, None, None]: # --------------------------------------------------------------------------- -def test_insert_opm_diagram_returns_int(tmp_db: Path): +def test_insert_opm_diagram_returns_int(tmp_db): from web.app import db as db_module payload = {"version": "1.0", "nodes": [], "links": []} diagram_id = db_module.insert_opm_diagram(payload) assert isinstance(diagram_id, int) -def test_insert_opm_diagram_multiple_distinct_ids(tmp_db: Path): +def test_insert_opm_diagram_multiple_distinct_ids(tmp_db): from web.app import db as db_module payload = {"version": "1.0", "nodes": [], "links": []} id1 = db_module.insert_opm_diagram(payload) @@ -264,7 +269,7 @@ def test_insert_opm_diagram_multiple_distinct_ids(tmp_db: Path): assert id1 != id2 -def test_get_opm_diagram_returns_parsed_dict(tmp_db: Path): +def test_get_opm_diagram_returns_parsed_dict(tmp_db): from web.app import db as db_module payload = {"version": "1.0", "nodes": [{"id": "x"}], "links": []} diagram_id = db_module.insert_opm_diagram(payload) @@ -274,12 +279,12 @@ def test_get_opm_diagram_returns_parsed_dict(tmp_db: Path): assert row["diagram"]["version"] == "1.0" -def test_get_opm_diagram_not_found_returns_none(tmp_db: Path): +def test_get_opm_diagram_not_found_returns_none(tmp_db): from web.app import db as db_module assert db_module.get_opm_diagram(99999) is None -def test_note_id_null_when_not_provided(tmp_db: Path): +def test_note_id_null_when_not_provided(tmp_db): from web.app import db as db_module payload = {"version": "1.0", "nodes": [], "links": []} diagram_id = db_module.insert_opm_diagram(payload) @@ -288,7 +293,7 @@ def test_note_id_null_when_not_provided(tmp_db: Path): assert row["note_id"] is None -def test_list_opm_diagrams_returns_all(tmp_db: Path): +def test_list_opm_diagrams_returns_all(tmp_db): from web.app import db as db_module payload = {"version": "1.0", "nodes": [], "links": []} db_module.insert_opm_diagram(payload) @@ -297,7 +302,7 @@ def test_list_opm_diagrams_returns_all(tmp_db: Path): assert len(diagrams) == 2 -def test_stored_payload_matches_inserted(tmp_db: Path): +def test_stored_payload_matches_inserted(tmp_db): from web.app import db as db_module payload = {"version": "1.0", "nodes": [{"id": "n1"}], "links": []} diagram_id = db_module.insert_opm_diagram(payload) @@ -312,23 +317,19 @@ def test_stored_payload_matches_inserted(tmp_db: Path): @pytest.fixture() -def client(tmp_db: Path) -> Generator[TestClient, None, None]: - # ollama is not installed in this environment; stub it so the app can be imported +def client(tmp_db) -> Generator[TestClient, None, None]: sys.modules.setdefault("ollama", MagicMock()) - # Remove cached app import so the patched DB_PATH takes effect - for mod in list(sys.modules): - if mod.startswith("web.app"): - del sys.modules[mod] from web.app.main import app - with patch( - "web.app.services.opm_extract.call_llm", side_effect=fake_opm_llm_success - ): - with TestClient(app) as c: - yield c + with patch("web.app.db.engine", tmp_db): + with patch( + "web.app.services.opm_extract.call_llm", side_effect=fake_opm_llm_success + ): + with TestClient(app) as c: + yield c -def test_post_extract_inserts_row(client: TestClient, tmp_db: Path): +def test_post_extract_inserts_row(client: TestClient, tmp_db): from web.app import db as db_module response = client.post("/opm/extract", json={"text": "some text", "save_note": False}) assert response.status_code == 200 diff --git a/web/tests/test_opm_validate.py b/web/tests/test_opm_validate.py index 34283c4..96922b0 100644 --- a/web/tests/test_opm_validate.py +++ b/web/tests/test_opm_validate.py @@ -2,7 +2,6 @@ import logging import sys -from pathlib import Path from typing import Generator from unittest.mock import MagicMock, patch @@ -355,38 +354,47 @@ def test_duplicate_relation_warns(caplog): @pytest.fixture() -def client() -> Generator: +def _test_engine(): + from sqlalchemy import create_engine as _create_engine + from sqlalchemy.pool import StaticPool + test_engine = _create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + with patch("web.app.db.engine", test_engine): + import web.app.db as db_module + db_module.init_db() + yield test_engine + + +@pytest.fixture() +def client(_test_engine) -> Generator: sys.modules.setdefault("ollama", MagicMock()) from fastapi.testclient import TestClient from web.app.main import app - with patch( - "web.app.services.opm_extract.call_llm", side_effect=fake_opm_llm_success - ): - with TestClient(app) as c: - yield c + with patch("web.app.db.engine", _test_engine): + with patch( + "web.app.services.opm_extract.call_llm", side_effect=fake_opm_llm_success + ): + with TestClient(app) as c: + yield c -def test_valid_stub_passes_validation_and_stores(client, tmp_path): - with patch("web.app.db.DB_PATH", tmp_path / "test.db"): - from web.app import db as db_module - db_module.init_db() - response = client.post("/opm/extract", json={"text": "some text", "save_note": False}) +def test_valid_stub_passes_validation_and_stores(client): + response = client.post("/opm/extract", json={"text": "some text", "save_note": False}) assert response.status_code == 200 data = response.json() assert data["diagram"]["version"] == "1.0" -def test_invalid_diagram_blocked_before_db_insert(client, tmp_path, monkeypatch): +def test_invalid_diagram_blocked_before_db_insert(client, monkeypatch): """An invalid dict from extraction must be rejected with 422, not stored.""" bad_diagram = {"version": "1.0", "nodes": [{"id": "BadID", "kind": "object", "label": "X"}], "links": []} - # Patch the name as imported in the router module monkeypatch.setattr("web.app.services.opm_extract.extract_opm_diagram", lambda text: bad_diagram) - with patch("web.app.db.DB_PATH", tmp_path / "test.db"): - from web.app import db as db_module - db_module.init_db() - response = client.post("/opm/extract", json={"text": "some text", "save_note": False}) + response = client.post("/opm/extract", json={"text": "some text", "save_note": False}) assert response.status_code == 422 body = response.json() @@ -394,15 +402,13 @@ def test_invalid_diagram_blocked_before_db_insert(client, tmp_path, monkeypatch) assert body["detail"]["error"] == "opm_extraction_failed" -def test_invalid_diagram_not_persisted(client, tmp_path, monkeypatch): +def test_invalid_diagram_not_persisted(client, _test_engine, monkeypatch): """After a validation failure, no row should appear in opm_diagrams.""" bad_diagram = {"version": "1.0", "nodes": [{"id": "Bad", "kind": "object", "label": "X"}], "links": []} monkeypatch.setattr("web.app.services.opm_extract.extract_opm_diagram", lambda text: bad_diagram) - with patch("web.app.db.DB_PATH", tmp_path / "test.db"): - from web.app import db as db_module - db_module.init_db() - client.post("/opm/extract", json={"text": "some text", "save_note": False}) - diagrams = db_module.list_opm_diagrams() + client.post("/opm/extract", json={"text": "some text", "save_note": False}) + import web.app.db as db_module + diagrams = db_module.list_opm_diagrams() assert diagrams == [] From 3968edb281053963eed3dbf9708adc03f146257e Mon Sep 17 00:00:00 2001 From: Candice0313 Date: Fri, 22 May 2026 13:57:58 -0500 Subject: [PATCH 2/2] add Docker Compose support for web service --- docker-compose.web.yml => docker-compose.yml | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) rename docker-compose.web.yml => docker-compose.yml (50%) diff --git a/docker-compose.web.yml b/docker-compose.yml similarity index 50% rename from docker-compose.web.yml rename to docker-compose.yml index 814e402..3607a93 100644 --- a/docker-compose.web.yml +++ b/docker-compose.yml @@ -1,27 +1,20 @@ -# VM / server deploy for OPM web app (testcase.md). Do not commit secrets. -# Usage: -# export OPM_MODEL=your-model -# export OPENAI_API_KEY=your-key -# docker compose -f docker-compose.web.yml up -d --build -# -# SQLite persists in Docker volume `opm_web_data`. UI: http://HOST:8000/ - services: - opm-web: + web: build: context: . dockerfile: Dockerfile.web ports: - - "8000:8000" + - "3000:8000" environment: WEB_DATA_DIR: /data + DEPLOY_TYPE: ${DEPLOY_TYPE:-DEV} OPM_MODEL: ${OPM_MODEL:-} OPENAI_API_KEY: ${OPENAI_API_KEY:-} OPENAI_BASE_URL: ${OPENAI_BASE_URL:-} OPM_OPENAI_EXTRA_HEADERS: ${OPM_OPENAI_EXTRA_HEADERS:-} OPM_TOP_P: ${OPM_TOP_P:-} volumes: - - opm_web_data:/data + - web_data:/data volumes: - opm_web_data: + web_data: