# --- api/db/database.py -----------------------------------------------------
import os

from sqlmodel import Session, create_engine

# Anchor the SQLite file next to this module so the database location does
# not depend on the process working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'fireform.db')}"

engine = create_engine(
    DATABASE_URL,
    echo=False,  # SQL logging off outside of debugging
    connect_args={"check_same_thread": False},  # FastAPI handlers run in a threadpool
)


# --- api/db/init_db.py -------------------------------------------------------
from sqlmodel import SQLModel

# FIX: the patch switched these to script-local imports ("from database import
# engine" / "import models"), which breaks importing ``api.db.init_db`` as a
# package module. Restore the absolute form used everywhere else in this
# patch (e.g. api/db/repositories.py imports ``api.db.models``).
from api.db.database import engine
from api.db import models  # noqa: F401  -- imported so model tables register on SQLModel.metadata


def init_db() -> None:
    """Create every table registered on SQLModel.metadata."""
    SQLModel.metadata.create_all(engine)
from datetime import datetime
from enum import Enum

from sqlalchemy import Column, JSON
from sqlmodel import Field, SQLModel, UniqueConstraint


class Template(SQLModel, table=True):
    """A fillable PDF form template and the fields discovered in it."""

    id: int | None = Field(default=None, primary_key=True)
    name: str = Field(unique=True)
    # Mapping of PDF field name -> field value/metadata, stored as JSON.
    fields: dict = Field(sa_column=Column(JSON))
    pdf_path: str
    # NOTE(review): datetime.utcnow is deprecated in Python 3.12+ (naive UTC
    # timestamps); kept as-is for compatibility with existing rows.
    created_at: datetime = Field(default_factory=datetime.utcnow)


class FormSubmission(SQLModel, table=True):
    """One filled-form event: the input text and the generated PDF path."""

    id: int | None = Field(default=None, primary_key=True)
    template_id: int
    input_text: str
    output_pdf_path: str
    created_at: datetime = Field(default_factory=datetime.utcnow)


class ReportSchema(SQLModel, table=True):
    """A report definition that aggregates one or more templates."""

    id: int | None = Field(default=None, primary_key=True)
    name: str = Field(unique=True)
    description: str
    use_case: str
    created_at: datetime = Field(default_factory=datetime.utcnow)


class ReportSchemaTemplate(SQLModel, table=True):
    """Junction row linking a Template to a ReportSchema.

    ``field_mapping`` stores the canonical-name -> PDF-field mapping that is
    generated after canonization (see repositories.update_template_mapping).
    """

    id: int | None = Field(default=None, primary_key=True)
    template_id: int
    report_schema_id: int
    # FIX: use default_factory instead of default={} so each row gets its own
    # dict rather than all rows sharing one mutable default instance.
    field_mapping: dict = Field(default_factory=dict, sa_column=Column(JSON))

    # A template may be linked to a given schema at most once.
    __table_args__ = (UniqueConstraint("template_id", "report_schema_id"),)


class Datatype(str, Enum):
    """Logical data type of a schema field."""

    STRING = "string"
    INT = "int"
    DATE = "date"
    ENUM = "enum"  # quote style normalized for consistency


class SchemaField(SQLModel, table=True):
    """Per-schema metadata for a single field sourced from a template.

    The same template field can appear in several schemas with independent
    metadata; ``canonical_name`` is assigned during canonization.
    """

    id: int | None = Field(default=None, primary_key=True)
    report_schema_id: int
    field_name: str
    source_template_id: int
    description: str = Field(default="")
    data_type: Datatype = Field(default=Datatype.STRING)
    word_limit: int | None = Field(default=None)
    required: bool = Field(default=False)
    # FIX: explicit default=None so instances created without allowed_values
    # do not depend on an implicit missing default.
    allowed_values: dict | None = Field(default=None, sa_column=Column(JSON))
    canonical_name: str | None = Field(default=None)
# ---------------------------------------------------------------------------
# Template CRUD
# ---------------------------------------------------------------------------

def create_template(session: Session, template: Template) -> Template:
    """Persist a new template and return the refreshed row.

    A duplicate ``name`` surfaces as sqlalchemy IntegrityError from commit();
    the previous ``try/except IntegrityError: raise`` wrapper was a no-op and
    has been removed (behavior unchanged).
    """
    session.add(template)
    session.commit()
    session.refresh(template)
    return template


def get_template(session: Session, template_id: int) -> Template | None:
    """Fetch a template by primary key, or None when absent."""
    return session.get(Template, template_id)


def update_template(session: Session, template_id: int, updates: dict) -> Template | None:
    """Apply ``updates`` (attribute name -> value) to a template.

    Returns the refreshed template, or None when it does not exist.
    """
    template = session.get(Template, template_id)
    if not template:
        return None
    for key, value in updates.items():
        setattr(template, key, value)
    session.add(template)
    session.commit()
    session.refresh(template)
    return template


def list_templates(session: Session) -> list[Template]:
    """Return every stored template."""
    return session.exec(select(Template)).all()


def delete_template(session: Session, template_id: int) -> bool:
    """Remove a template and its dependent rows.

    Cascades over form submissions, schema junctions, and the SchemaFields
    this template contributed to each linked schema. Returns False when the
    template does not exist.
    """
    template = session.get(Template, template_id)
    if not template:
        return False

    for form in session.exec(
        select(FormSubmission).where(FormSubmission.template_id == template_id)
    ).all():
        session.delete(form)

    for junction in session.exec(
        select(ReportSchemaTemplate).where(
            ReportSchemaTemplate.template_id == template_id
        )
    ).all():
        # Only the fields sourced from this template are removed; fields the
        # schema received from other templates are left untouched.
        for field in session.exec(
            select(SchemaField).where(
                SchemaField.report_schema_id == junction.report_schema_id,
                SchemaField.source_template_id == template_id,
            )
        ).all():
            session.delete(field)
        session.delete(junction)

    session.delete(template)
    session.commit()
    return True


# ---------------------------------------------------------------------------
# FormSubmission CRUD
# ---------------------------------------------------------------------------

def create_form(session: Session, form: FormSubmission) -> FormSubmission:
    """Persist a new form submission and return the refreshed row."""
    session.add(form)
    session.commit()
    session.refresh(form)
    return form


def get_form(session: Session, form_id: int) -> FormSubmission | None:
    """Fetch a form submission by primary key, or None when absent.

    FIX: return annotation corrected to ``FormSubmission | None`` —
    ``session.get`` returns None for a missing row.
    """
    return session.get(FormSubmission, form_id)


def update_form(session: Session, form_id: int, updates: dict) -> FormSubmission | None:
    """Apply ``updates`` to a form submission; None when it does not exist."""
    form = session.get(FormSubmission, form_id)
    if not form:
        return None
    for key, value in updates.items():
        setattr(form, key, value)
    session.add(form)
    session.commit()
    session.refresh(form)
    return form


def delete_form(session: Session, form_id: int) -> bool:
    """Delete a form submission; returns False when it does not exist.

    FIX: return annotation corrected to ``bool`` — the function always
    returned True/False, never a FormSubmission.
    """
    form_submission = session.get(FormSubmission, form_id)
    if not form_submission:
        return False
    session.delete(form_submission)
    session.commit()
    return True


# ---------------------------------------------------------------------------
# ReportSchema CRUD
# ---------------------------------------------------------------------------

def create_report_schema(session: Session, schema: ReportSchema) -> ReportSchema:
    """Persist a new report schema; duplicate names raise IntegrityError."""
    session.add(schema)
    session.commit()
    session.refresh(schema)
    return schema


def get_report_schema(session: Session, schema_id: int) -> ReportSchema | None:
    """Fetch a report schema by primary key, or None when absent."""
    return session.get(ReportSchema, schema_id)


def list_report_schemas(session: Session) -> list[ReportSchema]:
    """Return every stored report schema."""
    return session.exec(select(ReportSchema)).all()


def update_report_schema(session: Session, schema_id: int, updates: dict) -> ReportSchema | None:
    """Apply ``updates`` to a report schema; None when it does not exist."""
    schema = session.get(ReportSchema, schema_id)
    if not schema:
        return None
    for key, value in updates.items():
        setattr(schema, key, value)
    session.add(schema)
    session.commit()
    session.refresh(schema)
    return schema


def delete_report_schema(session: Session, schema_id: int) -> bool:
    """Delete a schema plus all of its SchemaFields and template junctions.

    Returns False when the schema does not exist.
    """
    schema = session.get(ReportSchema, schema_id)
    if not schema:
        return False

    for field in session.exec(
        select(SchemaField).where(SchemaField.report_schema_id == schema_id)
    ).all():
        session.delete(field)

    for junction in session.exec(
        select(ReportSchemaTemplate).where(
            ReportSchemaTemplate.report_schema_id == schema_id
        )
    ).all():
        session.delete(junction)

    session.delete(schema)
    session.commit()
    return True
def add_template_to_schema(
    session: Session, schema_id: int, template_id: int
) -> ReportSchemaTemplate:
    """Associate a template with a schema.

    Looks up ``template.fields`` and auto-creates a SchemaField for each
    field, pre-populated with ``field_name`` and ``source_template_id``.
    Other metadata is left at defaults for the user to fill in later.

    Raises:
        ValueError: when either the template or the schema does not exist.

    NOTE(review): the models declare UniqueConstraint(template_id,
    report_schema_id), so a duplicate association should raise
    IntegrityError at commit — yet the unit tests expect duplicates to
    succeed. Confirm which behavior is intended; the dead commented-out
    duplicate check that was here has been removed.
    """
    template = session.get(Template, template_id)
    if not template:
        raise ValueError(f"Template {template_id} not found")

    schema = session.get(ReportSchema, schema_id)
    if not schema:
        raise ValueError(f"ReportSchema {schema_id} not found")

    # Junction record; field_mapping stays empty until canonization.
    junction = ReportSchemaTemplate(
        report_schema_id=schema_id,
        template_id=template_id,
    )
    session.add(junction)

    # One SchemaField per template field, metadata left at defaults.
    for field_name in template.fields:
        session.add(
            SchemaField(
                report_schema_id=schema_id,
                field_name=field_name,
                source_template_id=template_id,
            )
        )

    session.commit()
    session.refresh(junction)
    return junction


def remove_template_from_schema(
    session: Session, schema_id: int, template_id: int
) -> bool:
    """Disassociate a template from a schema and remove its SchemaField rows.

    Returns False when no such association exists.
    """
    junction = session.exec(
        select(ReportSchemaTemplate).where(
            ReportSchemaTemplate.report_schema_id == schema_id,
            ReportSchemaTemplate.template_id == template_id,
        )
    ).first()
    if not junction:
        return False

    for field in session.exec(
        select(SchemaField).where(
            SchemaField.report_schema_id == schema_id,
            SchemaField.source_template_id == template_id,
        )
    ).all():
        session.delete(field)

    session.delete(junction)
    session.commit()
    return True


def get_schema_fields(session: Session, schema_id: int) -> list[SchemaField]:
    """Return every SchemaField belonging to the given schema."""
    return session.exec(
        select(SchemaField).where(SchemaField.report_schema_id == schema_id)
    ).all()
def get_schema_field(session: Session, field_id: int) -> SchemaField | None:
    """Fetch a single SchemaField by primary key, or None when absent.

    FIX: return annotation corrected to ``SchemaField | None``.
    """
    return session.get(SchemaField, field_id)


def update_schema_field(session: Session, schema_id: int, field_id: int, updates: dict) -> SchemaField | None:
    """Update field metadata: description, data_type, word_limit, required,
    allowed_values, canonical_name.

    Validates that the field belongs to the given schema before updating, so
    the same template field in different schemas keeps independent metadata.
    Returns None when the field is missing or belongs to another schema.
    """
    field = session.get(SchemaField, field_id)
    if not field or field.report_schema_id != schema_id:
        return None
    for key, value in updates.items():
        setattr(field, key, value)
    session.add(field)
    session.commit()
    session.refresh(field)
    return field


# ── Template mapping (post-canonization) ─────────────────────────────────────

def update_template_mapping(
    session: Session, schema_id: int, template_id: int
) -> ReportSchemaTemplate | None:
    """Generate and store the canonical → PDF field mapping after canonization.

    Groups this schema+template pair's SchemaFields by canonical name
    (falling back to the raw ``field_name`` when a field was never canonized).
    A canonical name backed by one PDF field maps to a str; one shared by
    several PDF fields maps to a list (the distribute step handles both).
    Returns None when the template is not associated with the schema.
    """
    junction = session.exec(
        select(ReportSchemaTemplate).where(
            ReportSchemaTemplate.report_schema_id == schema_id,
            ReportSchemaTemplate.template_id == template_id,
        )
    ).first()
    if not junction:
        return None

    fields = session.exec(
        select(SchemaField).where(
            SchemaField.report_schema_id == schema_id,
            SchemaField.source_template_id == template_id,
        )
    ).all()

    # Sort by field_name so list ordering inside the mapping is deterministic.
    grouped: defaultdict[str, list[str]] = defaultdict(list)
    for field in sorted(fields, key=lambda f: f.field_name):
        grouped[field.canonical_name or field.field_name].append(field.field_name)

    field_mapping: dict = {
        key: names[0] if len(names) == 1 else names for key, names in grouped.items()
    }

    junction.field_mapping = field_mapping
    session.add(junction)
    session.commit()
    session.refresh(junction)
    return junction


def get_field_mapping(session: Session, schema_id: int, template_id: int) -> dict | None:
    """Return the stored field mapping for a schema/template pair.

    FIX: was annotated as returning ReportSchemaTemplate while actually
    returning ``junction.field_mapping``, and raised AttributeError when the
    association did not exist; it now returns None in that case.
    """
    junction = session.exec(
        select(ReportSchemaTemplate).where(
            ReportSchemaTemplate.report_schema_id == schema_id,
            ReportSchemaTemplate.template_id == template_id,
        )
    ).first()
    return junction.field_mapping if junction else None
from pydantic import BaseModel
from datetime import datetime
from api.db.models import Datatype


class ReportSchemaCreate(BaseModel):
    """Payload for creating a report schema."""
    name: str
    description: str
    use_case: str

class ReportSchemaUpdate(BaseModel):
    """Partial update payload; None fields are left unchanged."""
    name: str | None = None
    description: str | None = None
    use_case: str | None = None

class TemplateAssociation(BaseModel):
    """Request body for linking an existing template to a schema."""
    template_id: int

class ReportFill(BaseModel):
    """Request body for end-to-end report generation from one transcript."""
    input_text: str

class ReportFillResponse(BaseModel):
    """Result of a fill run: one output PDF per template in the schema."""
    schema_id: int
    input_text: str
    output_pdf_paths: list[str]

class SchemaFieldUpdate(BaseModel):
    """Partial update of SchemaField metadata; None fields are unchanged."""
    description: str | None = None
    data_type: Datatype | None = None
    word_limit: int | None = None
    required: bool | None = None
    allowed_values: dict | None = None
    canonical_name: str | None = None


class SchemaFieldResponse(BaseModel):
    """Serialized SchemaField row (built from ORM attributes)."""
    id: int
    report_schema_id: int
    field_name: str
    source_template_id: int
    description: str
    data_type: Datatype
    word_limit: int | None
    required: bool
    allowed_values: dict | None
    canonical_name: str | None

    class Config:
        from_attributes = True

class TemplateInSchema(BaseModel):
    """Serialized ReportSchemaTemplate junction row."""
    id: int
    template_id: int
    report_schema_id: int
    field_mapping: dict

    class Config:
        from_attributes = True

class ReportSchemaResponse(BaseModel):
    """Schema plus its associated template junctions and fields."""
    id: int
    name: str
    description: str
    use_case: str
    created_at: datetime
    templates: list[TemplateInSchema] = []
    fields: list[SchemaFieldResponse] = []

    class Config:
        from_attributes = True


class CanonicalFieldEntry(BaseModel):
    """One canonical field with the source fields merged under it."""
    canonical_name: str
    description: str
    data_type: Datatype
    word_limit: int | None
    required: bool
    allowed_values: dict | None
    source_fields: list[SchemaFieldResponse]

class CanonicalSchema(BaseModel):
    """Canonized view of a report schema returned by the canonize endpoint."""
    report_schema_id: int
    canonical_fields: list[CanonicalFieldEntry]
# In-memory SQLite engine shared by all tests in this module.
test_engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})


@pytest.fixture(name="session")
def session_fixture():
    # Fresh schema per test: create all tables, yield a session, then drop
    # everything so tests remain independent of one another.
    SQLModel.metadata.create_all(test_engine)
    with Session(test_engine) as session:
        yield session
    SQLModel.metadata.drop_all(test_engine)


def _mk_schema(session: Session, name: str = "schema") -> ReportSchema:
    # Helper: create a ReportSchema with description/use_case derived from name.
    return create_report_schema(session, ReportSchema(name=name, description=f"{name}-desc", use_case=f"{name}-use"))


def _mk_template(session: Session, name: str = "template", fields: dict | None = None) -> Template:
    # Helper: create a Template; defaults to a single placeholder field.
    return create_template(
        session,
        Template(name=name, fields=fields if fields is not None else {"f1": "v1"}, pdf_path=f"{name}.pdf"),
    )


def test_create_get_update_and_delete_template(session: Session):
    created = _mk_template(session, "t-main", {"a": "b"})

    # test that creation is accurate
    assert created.id is not None
    assert created.name == "t-main"
    assert created.fields == {"a": "b"}
    assert created.pdf_path == "t-main.pdf"

    fetched = get_template(session, created.id)

    # test whether the fetched and created templates match
    assert fetched is not None
    assert fetched.id == created.id
    assert fetched.name == "t-main"
    assert fetched.fields == {"a": "b"}
    assert fetched.pdf_path == "t-main.pdf"

    # test whether updates are persistent and are done correctly
    _ = update_template(session, fetched.id, { "name" : "updated-name", "fields" :{"ua" : "ub"}, "pdf_path" : "t-updated.pdf"})
    updated = get_template(session, fetched.id)
    assert updated is not None
    assert updated.id == created.id
    assert updated.name == "updated-name"
    assert updated.fields == {"ua": "ub"}
    assert updated.pdf_path == "t-updated.pdf"

    # test that deleting works and double deleting does not work
    assert delete_template(session, fetched.id) is True
    assert delete_template(session, fetched.id) is False

    # test that getting a template that does not exist returns None
    assert get_template(session, 999999) is None


def test_delete_template_cascades_forms_and_schema_links(session: Session):
    schema = _mk_schema(session, "s-cascade")
    tpl = _mk_template(session, "t-cascade", {"a": "string", "b": "string"})
    add_template_to_schema(session, schema_id=schema.id, template_id=tpl.id)
    assert len(get_schema_fields(session, schema.id)) == 2

    form = create_form(
        session,
        FormSubmission(
            template_id=tpl.id,
            input_text="hi",
            output_pdf_path="/out.pdf",
        ),
    )

    # deleting the template must also remove its schema fields, its form
    # submissions, and the schema junction rows
    assert delete_template(session, tpl.id) is True
    assert get_template(session, tpl.id) is None
    assert get_schema_fields(session, schema.id) == []
    assert get_form(session, form.id) is None
    assert (
        session.exec(
            select(ReportSchemaTemplate).where(
                ReportSchemaTemplate.template_id == tpl.id
            )
        ).first()
        is None
    )


def test_create_get_update_and_delete_form_submission(session: Session):
    form = FormSubmission(template_id=123, input_text="sample input", output_pdf_path="/tmp/out.pdf")
    created = create_form(session, form)

    # test creation of form is correct
    assert created.id is not None
    assert created.template_id == 123
    assert created.input_text == "sample input"
    assert created.output_pdf_path == "/tmp/out.pdf"

    fetched = get_form(session, created.id)

    # test whether the fetched and created forms match
    assert fetched.id == created.id
    assert fetched.template_id == 123
    assert fetched.input_text == "sample input"
    assert fetched.output_pdf_path == "/tmp/out.pdf"

    # test whether updates are persistent and are done correctly
    _ = update_form(session, fetched.id, { "template_id" : 321, "input_text" : "input sample", "output_pdf_path" : "t-updated.pdf"})
    updated = get_form(session, fetched.id)
    assert updated is not None
    assert updated.id == created.id
    assert updated.template_id == 321
    assert updated.input_text == "input sample"
    assert updated.output_pdf_path == "t-updated.pdf"

    # test that deletion works and double deletion does not work
    assert delete_form(session, fetched.id) is True
    assert delete_form(session, fetched.id) is False

    # test that getting a form that does not exist returns None
    assert get_form(session, 999999) is None


def test_create_get_list_update_and_delete_report_schema(session: Session):
    s1 = _mk_schema(session, "s1")
    s2 = _mk_schema(session, "s2")

    # implicitly tests schema creation and directly tests fetching
    fetched = get_report_schema(session, s1.id)
    assert fetched is not None
    assert fetched.id == s1.id
    assert fetched.name == "s1"
    assert fetched.description == "s1-desc"
    assert fetched.use_case == "s1-use"

    # test getting a schema that does not exist returns None
    assert get_report_schema(session, 999999) is None

    # test listing all schemas works correctly
    listed = list_report_schemas(session)
    assert {s.name for s in listed} == {"s1", "s2"}

    # test updating a schema works correctly
    updated = update_report_schema(session, s1.id, {"name": "s1-new", "use_case": "u-new"})
    assert updated is not None
    assert updated.name == "s1-new"
    assert updated.description == "s1-desc"
    assert updated.use_case == "u-new"

    # test updating a schema that does not exist does not work
    assert update_report_schema(session, 999999, {"name": "x"}) is None


def test_add_template_to_schema_creates_junction_and_schema_fields(session: Session):
    schema = _mk_schema(session)
    template = _mk_template(session, fields={"field1": "x", "field2": "y"})

    _ = add_template_to_schema(session, schema.id, template.id)
    junction = session.get(ReportSchemaTemplate, _.id)

    # assert that the junction was created and created with the correct details
    assert junction is not None
    assert junction.report_schema_id == schema.id
    assert junction.template_id == template.id

    fields = get_schema_fields(session, schema.id)

    # assert correct number of fields were created
    assert len(fields) == 2

    # assert all fields were created using correct details
    assert {f.field_name for f in fields} == {"field1", "field2"}
    assert {f.source_template_id for f in fields} == {template.id}
    assert {f.report_schema_id for f in fields} == {schema.id}

def test_delete_report_schema_deletes_schema_fields_and_junctions(session: Session):
    schema = _mk_schema(session, "cascade")
    t1 = _mk_template(session, "t1", {"f1": "v1", "f2": "v2"})
    t2 = _mk_template(session, "t2", {"f3": "v3"})
    add_template_to_schema(session, schema.id, t1.id)
    add_template_to_schema(session, schema.id, t2.id)

    assert len(get_schema_fields(session, schema.id)) == 3
    assert session.query(ReportSchemaTemplate).count() == 2

    # test deletion works correctly and fields as well as junctions are deleted
    assert delete_report_schema(session, schema.id) is True
    assert get_report_schema(session, schema.id) is None
    assert get_schema_fields(session, schema.id) == []
    assert session.query(ReportSchemaTemplate).count() == 0

    # test double deletion and deleting schemas that do not exist
    assert delete_report_schema(session, schema.id) is False
    assert delete_report_schema(session, 424242) is False


def test_add_template_to_schema_supports_empty_template_fields(session: Session):
    schema = _mk_schema(session, "empty-fields")
    template = _mk_template(session, "empty-template", fields={})

    junction = add_template_to_schema(session, schema.id, template.id)

    # an empty fields dict creates the junction but no SchemaFields
    assert junction.id is not None
    assert get_schema_fields(session, schema.id) == []

def test_add_template_to_schema_raises_for_missing_template_or_schema(session: Session):
    schema = _mk_schema(session, "schema-only")
    template = _mk_template(session, "template-only")

    with pytest.raises(ValueError, match="Template .* not found"):
        add_template_to_schema(session, schema.id, 999999)

    with pytest.raises(ValueError, match="ReportSchema .* not found"):
        add_template_to_schema(session, 999999, template.id)


def test_add_template_to_schema_duplicate_association_creates_extra_rows(session: Session):
    # NOTE(review): this expectation contradicts the
    # UniqueConstraint(template_id, report_schema_id) declared on
    # ReportSchemaTemplate — with the constraint enforced the second call
    # should raise IntegrityError at commit. Confirm the intended behavior.
    schema = _mk_schema(session, "dup-schema")
    template = _mk_template(session, "dup-template", {"f1": "v1"})

    add_template_to_schema(session, schema.id, template.id)
    add_template_to_schema(session, schema.id, template.id)

    assert session.query(ReportSchemaTemplate).count() == 2
    fields = get_schema_fields(session, schema.id)
    assert len(fields) == 2
    assert all(field.field_name == "f1" for field in fields)


def test_remove_template_from_schema_removes_only_target_template_rows(session: Session):
    schema = _mk_schema(session, "remove")
    t1 = _mk_template(session, "t1", {"a": "1"})
    t2 = _mk_template(session, "t2", {"b": "2"})
    add_template_to_schema(session, schema.id, t1.id)
    add_template_to_schema(session, schema.id, t2.id)

    assert remove_template_from_schema(session, schema.id, t1.id) is True

    # only t1's fields are gone; t2's remain untouched
    remaining_fields = get_schema_fields(session, schema.id)
    assert len(remaining_fields) == 1
    assert remaining_fields[0].field_name == "b"
    assert remaining_fields[0].source_template_id == t2.id
    assert remove_template_from_schema(session, schema.id, t1.id) is False
    assert remove_template_from_schema(session, 101010, 202020) is False


def test_get_schema_fields_returns_fields_for_only_given_schema(session: Session):
    s1 = _mk_schema(session, "s1")
    s2 = _mk_schema(session, "s2")
    t1 = _mk_template(session, "t1", {"f1": "v1", "f2": "v2"})
    t2 = _mk_template(session, "t2", {"x": "y"})
    add_template_to_schema(session, s1.id, t1.id)
    add_template_to_schema(session, s2.id, t2.id)

    s1_fields = get_schema_fields(session, s1.id)
    assert len(s1_fields) == 2
    assert {f.field_name for f in s1_fields} == {"f1", "f2"}


def test_update_schema_field_updates_all_supported_metadata(session: Session):
    schema = _mk_schema(session, "meta")
    template = _mk_template(session, "meta-t", {"status": "draft"})
    add_template_to_schema(session, schema.id, template.id)
    field = get_schema_fields(session, schema.id)[0]

    updates = {
        "description": "Status of the workflow",
        "data_type": Datatype.ENUM,
        "word_limit": 3,
        "required": True,
        "allowed_values": {"values": ["draft", "final"]},
        "canonical_name": "status_canonical",
    }
    updated = update_schema_field(session, schema.id, field.id, updates)
    assert updated is not None

    # re-fetch to prove the update was persisted, not just applied in memory
    refreshed = session.get(SchemaField, field.id)
    assert refreshed is not None
    assert refreshed.description == updates["description"]
    assert refreshed.data_type == updates["data_type"]
    assert refreshed.word_limit == updates["word_limit"]
    assert refreshed.required is True
    assert refreshed.allowed_values == updates["allowed_values"]
    assert refreshed.canonical_name == updates["canonical_name"]


def test_update_schema_field_returns_none_for_missing_or_mismatched_field(session: Session):
    s1 = _mk_schema(session, "s1")
    s2 = _mk_schema(session, "s2")
    t = _mk_template(session, "t", {"f1": "v1"})
    add_template_to_schema(session, s1.id, t.id)
    field = get_schema_fields(session, s1.id)[0]

    # field belongs to s1, so updating through s2 must be rejected
    assert update_schema_field(session, s2.id, field.id, {"description": "x"}) is None
    assert update_schema_field(session, s1.id, 999999, {"description": "x"}) is None


def test_update_template_mapping_uses_canonical_name_or_fallback_field_name(session: Session):
    schema = _mk_schema(session, "mapping")
    template = _mk_template(session, "mapping-t", {"f1": "v1", "f2": "v2"})
    add_template_to_schema(session, schema.id, template.id)
    fields = sorted(get_schema_fields(session, schema.id), key=lambda f: f.field_name)

    update_schema_field(session, schema.id, fields[0].id, {"canonical_name": "canon_f1"})
    # fields[1] intentionally left without canonical_name to test fallback

    junction = update_template_mapping(session, schema.id, template.id)
    assert junction is not None
    assert junction.field_mapping == {"canon_f1": "f1", "f2": "f2"}


def test_update_template_mapping_returns_none_when_junction_missing(session: Session):
    schema = _mk_schema(session, "missing-junction")
    template = _mk_template(session, "missing-junction-t")

    # No call to add_template_to_schema, so no junction exists.
    assert update_template_mapping(session, schema.id, template.id) is None
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from api.routes import templates, forms, report_schemas
from api.errors.handlers import register_exception_handlers

# Application entry point: wires routers, error handlers, and CORS.
app = FastAPI()
register_exception_handlers(app)

# CORS: allow the local Vite dev server (5173) and preview server (4173)
# on both localhost spellings.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://127.0.0.1:5173",
        "http://localhost:5173",
        "http://127.0.0.1:4173",
        "http://localhost:4173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(templates.router)
app.include_router(forms.router)
app.include_router(report_schemas.router)
+ ReportFillResponse, + FormSubmissionResponse, +) +from api.db import repositories as repo +from api.db.models import ReportSchema +from src.report_schema import ReportSchemaProcessor +from src.controller import Controller +from api.db.models import FormSubmission, ReportSchemaTemplate +from sqlalchemy.exc import IntegrityError + +router = APIRouter(prefix="/schemas", tags=["schemas"]) + + +@router.post("/create", response_model=ReportSchemaResponse) +def create_schema(data: ReportSchemaCreate, db: Session = Depends(get_db)): + schema = ReportSchema(**data.model_dump()) + try: + return repo.create_report_schema(db, schema) + except IntegrityError: + raise HTTPException( + status_code=409, + detail="A schema with this name already exists" + ) + +@router.get("/", response_model=list[ReportSchemaResponse]) +def list_schemas(db: Session = Depends(get_db)): + return repo.list_report_schemas(db) + +@router.get("/{schema_id}", response_model=ReportSchemaResponse) +def get_schema(schema_id: int, db: Session = Depends(get_db)): + schema = repo.get_report_schema(db, schema_id) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + return schema + +@router.put("/{schema_id}", response_model=ReportSchemaResponse) +def update_schema(schema_id: int, data: ReportSchemaUpdate, db: Session = Depends(get_db)): + updates = data.model_dump(exclude_none=True) + schema = repo.update_report_schema(db, schema_id, updates) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + return schema + +@router.delete("/{schema_id}") +def delete_schema(schema_id: int, db: Session = Depends(get_db)): + deleted = repo.delete_report_schema(db, schema_id) + if not deleted: + raise HTTPException(status_code=404, detail="Schema not found") + return {"detail": "Schema deleted"} + + +@router.post("/{schema_id}/templates", response_model=list[SchemaFieldResponse]) +def add_template(schema_id: int, data: TemplateAssociation, db: Session = 
Depends(get_db)): + """Associate a template with a schema. + + Auto-creates SchemaField entries from template.fields and returns them. + """ + try: + repo.add_template_to_schema(db, schema_id, data.template_id) + except IntegrityError: + raise HTTPException(status_code=409, detail="Template is already added to schema") + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + return repo.get_schema_fields(db, schema_id) + +@router.delete("/{schema_id}/templates/{template_id}") +def remove_template(schema_id: int, template_id: int, db: Session = Depends(get_db)): + removed = repo.remove_template_from_schema(db, schema_id, template_id) + if not removed: + raise HTTPException(status_code=404, detail="Template association not found") + return {"detail": "Template disassociated"} + + + +@router.get("/{schema_id}/fields", response_model=list[SchemaFieldResponse]) +def list_fields(schema_id: int, db: Session = Depends(get_db)): + return repo.get_schema_fields(db, schema_id) + +@router.put("/{schema_id}/fields/{field_id}", response_model=SchemaFieldResponse) +def update_field(schema_id: int, field_id: int, data: SchemaFieldUpdate, db: Session = Depends(get_db)): + updates = data.model_dump(exclude_none=True) + field = repo.update_schema_field(db, schema_id, field_id, updates) + if not field: + raise HTTPException(status_code=404, detail="Field not found or does not belong to this schema") + return field + + +@router.post("/{schema_id}/canonize", response_model=CanonicalSchema) +def canonize_schema(schema_id: int, db: Session = Depends(get_db)): + """Trigger canonization: group fields, assign canonical names, generate field mappings.""" + schema = repo.get_report_schema(db, schema_id) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + + return ReportSchemaProcessor.canonize(db, schema_id) + +@router.get("/mapping/{schema_id}/{template_id}") +def get_schema_template_mapping(schema_id: int, template_id: int, db: 
Session = Depends(get_db)): + return repo.get_field_mapping(db, schema_id, template_id) + +@router.post("/{schema_id}/fill", response_model=ReportFillResponse) +def fill_schema(schema_id: int, data: ReportFill, db: Session = Depends(get_db)): + """ + End-to-end report generation. + Takes a single transcript, extracts canonical fields, distributes to + all schema templates, fills them, and logs the submissions. + """ + schema = repo.get_report_schema(db, schema_id) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + + controller = Controller() + + output_paths = controller.fill_report(db, data.input_text, schema_id) + + # Log submissions + submission_ids: list[int] = [] + for template_id, path in output_paths.items(): + submission = FormSubmission( + template_id=template_id, + report_schema_id=schema_id, + name=data.name, + input_text=data.input_text, + output_pdf_path=path + ) + db.add(submission) + db.flush() + submission_ids.append(submission.id) # type: ignore + + db.commit() + + return ReportFillResponse( + schema_id=schema_id, + input_text=data.input_text, + output_pdf_paths=list(output_paths.values()), + submission_ids=submission_ids, + ) + + +@router.get("/{schema_id}/submissions", response_model=list[FormSubmissionResponse]) +def list_submissions(schema_id: int, db: Session = Depends(get_db)): + """List all form submissions for a given schema.""" + return db.exec( + select(FormSubmission) + .where(FormSubmission.report_schema_id == schema_id) + .order_by(FormSubmission.created_at.desc()) # type: ignore + ).all() + + +@router.get("/submissions/{submission_id}/pdf") +def get_submission_pdf(submission_id: int, db: Session = Depends(get_db)): + """Serve a filled PDF for a given form submission.""" + submission = db.get(FormSubmission, submission_id) + if not submission: + raise HTTPException(status_code=404, detail="Submission not found") + path = Path(submission.output_pdf_path).resolve() + if not path.is_file(): + raise 
HTTPException(status_code=404, detail="PDF file missing on disk") + return FileResponse( + path, + media_type="application/pdf", + filename=f"submission_{submission_id}.pdf", + ) diff --git a/api/routes/templates.py b/api/routes/templates.py index 5c2281b..dcacecc 100644 --- a/api/routes/templates.py +++ b/api/routes/templates.py @@ -1,16 +1,131 @@ -from fastapi import APIRouter, Depends +import re +from sqlalchemy.exc import IntegrityError +import uuid +from pathlib import Path + +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from fastapi.responses import FileResponse from sqlmodel import Session + from api.deps import get_db -from api.schemas.templates import TemplateCreate, TemplateResponse -from api.db.repositories import create_template from api.db.models import Template +from api.db.repositories import ( + create_template, + delete_template, + get_template, + update_template, + list_templates +) +from api.schemas.templates import TemplateResponse, TemplateUpdate from src.controller import Controller router = APIRouter(prefix="/templates", tags=["templates"]) +INPUT_FILES_DIR = Path(__file__).resolve().parents[2] / "template_files" + + +def _safe_name_fragment(name: str) -> str: + base = Path(name).name + s = re.sub(r"[^\w\-.]+", "_", base.strip(), flags=re.UNICODE) + s = s.strip("._-") or "template" + return s[:120] + + @router.post("/create", response_model=TemplateResponse) -def create(template: TemplateCreate, db: Session = Depends(get_db)): +def create( + name: str = Form(...), + file: UploadFile = File(...), + db: Session = Depends(get_db), +): + filename = (file.filename or "").lower() + if not filename.endswith(".pdf"): + raise HTTPException(status_code=400, detail="File must be a .pdf") + + frag = _safe_name_fragment(name) + uid = uuid.uuid4().hex + INPUT_FILES_DIR.mkdir(parents=True, exist_ok=True) + dest = INPUT_FILES_DIR / f"{frag}_{uid}.pdf" + + raw = file.file.read() + if not raw: + raise 
HTTPException(status_code=400, detail="Empty file") + dest.write_bytes(raw) + controller = Controller() - template_path = controller.create_template(template.pdf_path) - tpl = Template(**template.model_dump(exclude={"pdf_path"}), pdf_path=template_path) - return create_template(db, tpl) \ No newline at end of file + try: + template_path = controller.create_template(str(dest)) + except Exception as e: + dest.unlink(missing_ok=True) + print(e) + raise HTTPException( + status_code=500, detail=f"Failed to prepare PDF template: {e}" + ) from e + + fields = controller.extract_template_fields(template_path) + tpl = Template(name=name.strip(), fields=fields, pdf_path=template_path) + + try: + return create_template(db, tpl) + except IntegrityError: + raise HTTPException( + status_code=409, + detail="A template with the same name already exists" + ) + +@router.get("/", response_model=list[Template]) +def list(db: Session = Depends(get_db)): + return list_templates(db) + + +@router.get("/{template_id}/pdf") +def get_template_pdf(template_id: int, db: Session = Depends(get_db)): + """Serve the stored PDF for preview in the schema wizard.""" + tpl = get_template(db, template_id) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + root = INPUT_FILES_DIR.resolve() + path = Path(tpl.pdf_path).resolve() + try: + path.relative_to(root) + except ValueError: + raise HTTPException(status_code=403, detail="Invalid template file location") + if not path.is_file(): + raise HTTPException(status_code=404, detail="PDF file missing on disk") + return FileResponse( + path, + media_type="application/pdf", + filename=f"{tpl.name}.pdf", + ) + + +@router.get("/{template_id}", response_model=TemplateResponse) +def get_one(template_id: int, db: Session = Depends(get_db)): + tpl = get_template(db, template_id) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + return tpl + + +@router.put("/{template_id}", 
response_model=TemplateResponse) +def update_one( + template_id: int, + data: TemplateUpdate, + db: Session = Depends(get_db), +): + updates = data.model_dump(exclude_none=True) + if not updates: + tpl = get_template(db, template_id) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + return tpl + tpl = update_template(db, template_id, updates) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + return tpl + + +@router.delete("/{template_id}") +def delete_one(template_id: int, db: Session = Depends(get_db)): + if not delete_template(db, template_id): + raise HTTPException(status_code=404, detail="Template not found") + return {"detail": "Template deleted"} diff --git a/api/schemas/report_class.py b/api/schemas/report_class.py index ebc3776..986887d 100644 --- a/api/schemas/report_class.py +++ b/api/schemas/report_class.py @@ -3,6 +3,19 @@ from api.db.models import Datatype +class FormSubmissionResponse(BaseModel): + id: int + template_id: int + report_schema_id: int | None + name: str | None + input_text: str + output_pdf_path: str + created_at: datetime + + class Config: + from_attributes = True + + class ReportSchemaCreate(BaseModel): name: str description: str @@ -18,11 +31,13 @@ class TemplateAssociation(BaseModel): class ReportFill(BaseModel): input_text: str + name: str | None = None class ReportFillResponse(BaseModel): schema_id: int input_text: str output_pdf_paths: list[str] + submission_ids: list[int] = [] class SchemaFieldUpdate(BaseModel): description: str | None = None diff --git a/api/schemas/templates.py b/api/schemas/templates.py index 961f219..8ef630e 100644 --- a/api/schemas/templates.py +++ b/api/schemas/templates.py @@ -1,9 +1,11 @@ from pydantic import BaseModel -class TemplateCreate(BaseModel): - name: str - pdf_path: str - fields: dict + +class TemplateUpdate(BaseModel): + name: str | None = None + fields: dict | None = None + pdf_path: str | None = None + class 
TemplateResponse(BaseModel): id: int diff --git a/requirements.txt b/requirements.txt index eaa6c81..558ac18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ pdfrw flask commonforms fastapi +python-multipart uvicorn pydantic sqlmodel diff --git a/src/controller.py b/src/controller.py index d31ec9c..a6bab12 100644 --- a/src/controller.py +++ b/src/controller.py @@ -1,4 +1,6 @@ from src.file_manipulator import FileManipulator +from sqlmodel import Session +from src.report_schema import ReportSchemaProcessor class Controller: def __init__(self): @@ -8,4 +10,24 @@ def fill_form(self, user_input: str, fields: list, pdf_form_path: str): return self.file_manipulator.fill_form(user_input, fields, pdf_form_path) def create_template(self, pdf_path: str): - return self.file_manipulator.create_template(pdf_path) \ No newline at end of file + return self.file_manipulator.create_template(pdf_path) + + def extract_template_fields(self, pdf_path: str) -> dict[str, str]: + return self.file_manipulator.extract_template_field_map(pdf_path) + + def fill_report(self, session: Session, user_input: str, schema_id: int) -> dict[int, str]: + """ + Main pipeline entry point for filling a multi-template report schema. + 1. Triggers canonization to get the latest schema definition. + 2. Builds the JSON Schema extraction target for the LLM. + 3. Hands off to FileManipulator for actual processing. 
+ """ + canonical_schema = ReportSchemaProcessor.canonize(session, schema_id) + extraction_target = ReportSchemaProcessor.build_extraction_target(canonical_schema) + + return self.file_manipulator.fill_report( + session=session, + user_input=user_input, + schema_id=schema_id, + canonical_target=extraction_target + ) \ No newline at end of file diff --git a/src/file_manipulator.py b/src/file_manipulator.py index b7815cc..9382fe7 100644 --- a/src/file_manipulator.py +++ b/src/file_manipulator.py @@ -1,7 +1,13 @@ import os +from pdfrw import PdfReader from src.filler import Filler from src.llm import LLM +from src.pdf_utils import decode_pdf_name from commonforms import prepare_form +from sqlmodel import Session +from src.report_schema import ReportSchemaProcessor +from api.db.models import Template + class FileManipulator: @@ -13,9 +19,23 @@ def create_template(self, pdf_path: str): """ By using commonforms, we create an editable .pdf template and we store it. """ - template_path = pdf_path[:-4] + "_template.pdf" - prepare_form(pdf_path, template_path) - return template_path + prepare_form(pdf_path, pdf_path) + return pdf_path + + def extract_template_field_map(self, pdf_path: str) -> dict[str, str]: + """AcroForm widget names from a PDF, each mapped to type ``string`` (Template.fields shape).""" + pdf = PdfReader(pdf_path) + names: list[str] = [] + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + for annot in page.Annots: + if getattr(annot, "Subtype", None) != "/Widget" or not getattr(annot, "T", None): + continue + raw = decode_pdf_name(str(annot.T).strip("() /")) + if raw and raw not in names: + names.append(raw) + return {n: "string" for n in names} def fill_form(self, user_input: str, fields: list, pdf_form_path: str): """ @@ -45,3 +65,57 @@ def fill_form(self, user_input: str, fields: list, pdf_form_path: str): print(f"An error occurred during PDF generation: {e}") # Re-raise the exception so the frontend can handle it raise e + + 
def fill_report(self, session: Session, user_input: str, schema_id: int, canonical_target: dict) -> dict[int, str]: + """ + Extracts data using a canonical schema target, distributes the results + to all associated templates, and fills them by name. + """ + print(f"[1] Received report fill request for schema {schema_id}.") + print("[2] Starting canonical extraction process...") + + try: + # 1. Extract against the canonical target + self.llm.set_model_config(provider="gemini", model_name="gemini-2.5-flash") + self.llm._target_fields = canonical_target + self.llm._transcript_text = user_input + + print("Canonization Process Begins") + + extraction_target = ReportSchemaProcessor.canonize(session=session, schema_id=schema_id) + + print("Canonization Process Complete") + + canonical_data = self.llm.extractor(extraction_target) + + + print(f"[3] Canonical extraction complete. Distributing to templates...") + + # 2. Distribute to per-template dictionaries + distribution = ReportSchemaProcessor.distribute(session, schema_id, canonical_data) + + # 3. 
Fill each template + output_paths: dict[int, str] = {} + + for template_id, template_data in distribution.items(): + template = session.get(Template, template_id) + if not template or not os.path.exists(template.pdf_path): + print(f" -> Skipping template {template_id} (not found or missing PDF)") + continue + + print(f" -> Filling template {template_id} ({template.name})...") + output_name = self.filler.fill_form_by_name( + pdf_form=template.pdf_path, + field_values=template_data + ) + output_paths[template_id] = output_name + + print("\n----------------------------------") + print("✅ Report generation complete.") + print(f"Outputs saved to: {list(output_paths.values())}") + + return output_paths + + except Exception as e: + print(f"An error occurred during report generation: {e}") + raise e diff --git a/src/filler.py b/src/filler.py index e31e535..74f8b98 100644 --- a/src/filler.py +++ b/src/filler.py @@ -1,6 +1,8 @@ -from pdfrw import PdfReader, PdfWriter +from pdfrw import PdfReader, PdfWriter, PdfDict, PdfObject from src.llm import LLM +from src.pdf_utils import decode_pdf_name from datetime import datetime +import uuid class Filler: @@ -15,7 +17,7 @@ def fill_form(self, pdf_form: str, llm: LLM): output_pdf = ( pdf_form[:-4] + "_" - + datetime.now().strftime("%Y%m%d_%H%M%S") + + str(uuid.uuid4()) + "_filled.pdf" ) @@ -28,14 +30,14 @@ def fill_form(self, pdf_form: str, llm: LLM): # Read PDF pdf = PdfReader(pdf_form) - # Loop through pages + # Global index across all pages (visual order is per page, pages in document order). 
+ i = 0 for page in pdf.pages: if page.Annots: sorted_annots = sorted( page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0])) ) - i = 0 for annot in sorted_annots: if annot.Subtype == "/Widget" and annot.T: if i < len(answers_list): @@ -43,10 +45,45 @@ def fill_form(self, pdf_form: str, llm: LLM): annot.AP = None i += 1 else: - # Stop if we run out of answers break PdfWriter().write(output_pdf, pdf) # Your main.py expects this function to return the path return output_pdf + + + def fill_form_by_name(self, pdf_form: str, field_values: dict[str, str]) -> str: + """ + Fill a PDF form with values from a dictionary mapped by field name. + Unlike `fill_form`, this does not rely on visual ordering, it relies on + the exact field name defined in the PDF template matching a key in `field_values`. + """ + output_pdf = ( + pdf_form[:-4] + + "_" + + str(uuid.uuid4()) + + "_filled.pdf" + ) + + # Read PDF + pdf = PdfReader(pdf_form) + + # Force generation of Appearance Streams so text is visible in standard viewers + if pdf.Root.AcroForm: + pdf.Root.AcroForm.update(PdfDict(NeedAppearances=PdfObject('true'))) + + # Loop through pages + for page in pdf.pages: + if page.Annots: + for annot in page.Annots: + if annot.Subtype == "/Widget" and annot.T: + field_name = decode_pdf_name(str(annot.T).strip("() /")) + + if field_name in field_values: + # Update the PDF annotation + annot.V = f"{field_values[field_name]}" + annot.AP = None + + PdfWriter().write(output_pdf, pdf) + return output_pdf diff --git a/src/pdf_utils.py b/src/pdf_utils.py new file mode 100644 index 0000000..c3ef2f3 --- /dev/null +++ b/src/pdf_utils.py @@ -0,0 +1,17 @@ +import re + +_PDF_ESCAPE_RE = re.compile(r'\\(\d{1,3}|[nrtbf()\\])') +_NAMED_ESCAPES = { + 'n': '\n', 'r': '\r', 't': '\t', 'b': '\b', 'f': '\f', + '(': '(', ')': ')', '\\': '\\', +} + + +def decode_pdf_name(raw: str) -> str: + """Decode all PDF literal-string escape sequences (ISO 32000 §7.3.4.2).""" + def _replace(m: re.Match) -> str: + s = 
m.group(1) + if s[0].isdigit(): + return chr(int(s, 8)) + return _NAMED_ESCAPES.get(s, s) + return _PDF_ESCAPE_RE.sub(_replace, raw) diff --git a/src/report_schema.py b/src/report_schema.py new file mode 100644 index 0000000..785792f --- /dev/null +++ b/src/report_schema.py @@ -0,0 +1,129 @@ +from sqlmodel import Session, select +from typing import Any +from api.db.models import SchemaField, ReportSchemaTemplate, Datatype +from api.schemas.report_class import CanonicalSchema, CanonicalFieldEntry, SchemaFieldResponse +from api.db.repositories import update_template_mapping, get_report_schema + + +class ReportSchemaProcessor: + @staticmethod + def canonize(session: Session, schema_id: int) -> CanonicalSchema: + """Group fields by their canonical names (falling back to original names).""" + schema = get_report_schema(session, schema_id) + if not schema: + raise ValueError(f"ReportSchema {schema_id} not found") + + # 1. Fetch all fields for this schema + fields = session.exec( + select(SchemaField).where(SchemaField.report_schema_id == schema_id) + ).all() + + # 2. Group fields by their effective canonical name + groups: dict[str, list[SchemaField]] = {} + + for field in fields: + # The manual override rule: If no canonical name is set, use the raw field name + effective_name = field.canonical_name if field.canonical_name else field.field_name + + if effective_name not in groups: + groups[effective_name] = [] + groups[effective_name].append(field) + + # 3. 
Build the CanonicalSchema representation + canonical_fields = [] + for effective_name, source_fields in groups.items(): + # Use metadata from the first field in the group as the canonical metadata + # (In a more complex system, we might merge these or let the user elect a "primary" field) + primary = source_fields[0] + + canonical_fields.append( + CanonicalFieldEntry( + canonical_name=effective_name, + description=primary.description, + data_type=primary.data_type, + word_limit=primary.word_limit, + required=primary.required, + allowed_values=primary.allowed_values, + source_fields=[SchemaFieldResponse.model_validate(f) for f in source_fields] + ) + ) + + # 4. Update the junction tables so they know how to map back + # We need to do this per-template + template_ids = {f.source_template_id for f in fields} + for t_id in template_ids: + update_template_mapping(session, schema_id, t_id) + + return CanonicalSchema( + report_schema_id=schema_id, + canonical_fields=canonical_fields + ) + + @staticmethod + def build_extraction_target(canonical_schema: CanonicalSchema) -> dict[str, Any]: + """Convert the CanonicalSchema into a JSON schema dict for LLM function calling.""" + properties = {} + required = [] + + type_mapping = { + Datatype.STRING: "string", + Datatype.INT: "integer", + Datatype.DATE: "string", # Represent dates as strings for LLM + Datatype.ENUM: "string" # Enums are strings restricted by allowed_values + } + + for field in canonical_schema.canonical_fields: + field_def = { + "type": type_mapping.get(field.data_type, "string"), + "description": field.description + } + + if field.data_type == Datatype.ENUM and field.allowed_values and "values" in field.allowed_values: + field_def["enum"] = field.allowed_values["values"] + + if field.word_limit: + field_def["description"] += f" (Maximum {field.word_limit} words)" + + properties[field.canonical_name] = field_def + + if field.required: + required.append(field.canonical_name) + + return { + "type": "object", + 
"properties": properties, + "required": required + } + + @staticmethod + def distribute( + session: Session, schema_id: int, canonical_data: dict[str, Any] + ) -> dict[int, dict[str, Any]]: + """Map canonical extraction output back to individual template fields.""" + junctions = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id + ) + ).all() + + distribution = {} + + for junction in junctions: + template_id = junction.template_id + mapping = junction.field_mapping or {} + + template_data = {} + for canonical_name, pdf_targets in mapping.items(): + if canonical_name not in canonical_data: + continue + names = ( + pdf_targets + if isinstance(pdf_targets, list) + else [pdf_targets] + ) + for pdf_field_name in names: + template_data[pdf_field_name] = canonical_data[canonical_name] + + distribution[template_id] = template_data + + return distribution diff --git a/tests/unit/test_controller.py b/tests/unit/test_controller.py new file mode 100644 index 0000000..e3cb634 --- /dev/null +++ b/tests/unit/test_controller.py @@ -0,0 +1,195 @@ +import sys +from pathlib import Path +from typing import Any + +import pytest +from sqlmodel import SQLModel, Session, create_engine + +sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) + +from api.db.models import Datatype, FormSubmission, ReportSchema, SchemaField, Template +from api.db.repositories import ( + add_template_to_schema, + create_report_schema, + create_template, + get_schema_fields, + update_schema_field, +) +from src.controller import Controller +from src.report_schema import ReportSchemaProcessor + + +test_engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + + +@pytest.fixture(name="session") +def session_fixture(): + SQLModel.metadata.create_all(test_engine) + with Session(test_engine) as session: + yield session + SQLModel.metadata.drop_all(test_engine) + + +class _FakeT2J: + def __init__(self, data: dict[str, 
Any]): + self._data = data + + def get_data(self) -> dict[str, Any]: + return self._data + + +def _setup_schema_with_two_templates(session: Session): + schema = create_report_schema( + session, + ReportSchema(name="s1", description="d1", use_case="u1"), + ) + + # Each template uses its own PDF field names, but both map to canonical + # names via SchemaField.canonical_name. + t1 = create_template( + session, + Template(name="t1", fields={"name_f1": "x", "age_f1": "y"}, pdf_path="t1.pdf"), + ) + t2 = create_template( + session, + Template(name="t2", fields={"name_f2": "x", "age_f2": "y"}, pdf_path="t2.pdf"), + ) + + add_template_to_schema(session=session, schema_id=schema.id, template_id=t1.id) + add_template_to_schema(session=session, schema_id=schema.id, template_id=t2.id) + + fields = get_schema_fields(session=session, schema_id=schema.id) + for f in fields: + if f.field_name in {"name_f1", "name_f2"}: + update_schema_field( + session, + schema_id=schema.id, + field_id=f.id, + updates={ + "canonical_name": "name", + "required": True, + "data_type": Datatype.STRING, + "word_limit": 10, + }, + ) + elif f.field_name in {"age_f1", "age_f2"}: + update_schema_field( + session, + schema_id=schema.id, + field_id=f.id, + updates={ + "canonical_name": "age", + "required": False, + "data_type": Datatype.INT, + }, + ) + else: + raise AssertionError(f"Unexpected field_name in fixture: {f.field_name}") + + return schema, t1, t2 + + +def test_controller_fill_report_uses_report_schema_processor_and_fills_templates(session: Session): + schema, t1, t2 = _setup_schema_with_two_templates(session) + + canonical_data = {"name": "Alice", "age": 30} + seen: dict[str, Any] = {} + + class FakeLLM: + def __init__(self): + self._transcript_text = None + self._target_fields = None + + def main_loop(self): + # Capture what FileManipulator passed to LLM. 
+ seen["transcript_text"] = self._transcript_text + seen["target_fields"] = self._target_fields + return _FakeT2J(canonical_data) + + class FakeFiller: + def __init__(self): + self.calls: list[tuple[str, dict[str, str]]] = [] + + def fill_form_by_name(self, pdf_form: str, field_values: dict[str, str]) -> str: + self.calls.append((pdf_form, field_values)) + return f"{pdf_form}__filled" + + with pytest.MonkeyPatch.context() as mp: + mp.setattr("src.file_manipulator.os.path.exists", lambda _: True) + mp.setattr("src.file_manipulator.LLM", FakeLLM) + mp.setattr("src.file_manipulator.Filler", FakeFiller) + + controller = Controller() + # Act + output = controller.fill_report(session=session, user_input="hello world", schema_id=schema.id) + + # Assert output paths per template. + assert output == {t1.id: "t1.pdf__filled", t2.id: "t2.pdf__filled"} + + # Assert report schema processor output shape was passed to LLM. + target_fields = seen["target_fields"] + assert "properties" in target_fields + assert set(target_fields["properties"].keys()) == {"name", "age"} + assert target_fields["properties"]["age"]["type"] == "integer" + + # Required list should include canonical field 'name' only. 
+ assert set(target_fields["required"]) == {"name"} + + assert seen["transcript_text"] == "hello world" + + +def test_controller_fill_form_delegates_to_file_manipulator(session: Session): + controller = Controller() + + seen: dict[str, Any] = {} + + def fake_fill_form(user_input: str, fields: list, pdf_form_path: str): + seen["user_input"] = user_input + seen["fields"] = fields + seen["pdf_form_path"] = pdf_form_path + return "/out.pdf" + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(controller.file_manipulator, "fill_form", fake_fill_form) + + out = controller.fill_form(user_input="u1", fields=["a", "b"], pdf_form_path="/in.pdf") + + assert out == "/out.pdf" + assert seen == {"user_input": "u1", "fields": ["a", "b"], "pdf_form_path": "/in.pdf"} + + +def test_controller_create_template_delegates_to_file_manipulator(session: Session): + controller = Controller() + + seen: dict[str, Any] = {} + + def fake_create_template(pdf_path: str): + seen["pdf_path"] = pdf_path + return "/out_template.pdf" + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(controller.file_manipulator, "create_template", fake_create_template) + + out = controller.create_template(pdf_path="/in.pdf") + + assert out == "/out_template.pdf" + assert seen == {"pdf_path": "/in.pdf"} + + +def test_controller_extract_template_fields_delegates_to_file_manipulator(): + controller = Controller() + + def fake_extract(pdf_path: str): + assert pdf_path == "/tpl.pdf" + return {"a": "string"} + + with pytest.MonkeyPatch.context() as mp: + mp.setattr( + controller.file_manipulator, + "extract_template_field_map", + fake_extract, + ) + out = controller.extract_template_fields("/tpl.pdf") + + assert out == {"a": "string"} + diff --git a/tests/unit/test_file_manipulator.py b/tests/unit/test_file_manipulator.py new file mode 100644 index 0000000..0ff1ad4 --- /dev/null +++ b/tests/unit/test_file_manipulator.py @@ -0,0 +1,333 @@ +import os +import sys +from pathlib import Path +from typing import Any 
+from unittest.mock import MagicMock, patch + +import pytest +from pdfrw import PdfReader +from sqlmodel import SQLModel, Session, create_engine + +sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) + +from api.db.models import Datatype, ReportSchema, Template +from api.db.repositories import ( + add_template_to_schema, + create_report_schema, + create_template, + get_schema_fields, + update_schema_field, +) +from src.file_manipulator import FileManipulator +from src.report_schema import ReportSchemaProcessor + + +FORMS_DIR = Path(__file__).resolve().parent.parent / "forms" +FW4_PDF = FORMS_DIR / "fw4.pdf" +FW4_TEMPLATE_PDF = FORMS_DIR / "fw4_template.pdf" +I9_PDF = FORMS_DIR / "i-9.pdf" + +test_engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + + +@pytest.fixture(name="session") +def session_fixture(): + SQLModel.metadata.create_all(test_engine) + with Session(test_engine) as session: + yield session + SQLModel.metadata.drop_all(test_engine) + + +def _clean_pdf_value(v: Any) -> str | None: + if v is None: + return None + s = str(v).strip() + if s.startswith("(") and s.endswith(")"): + s = s[1:-1] + return s.strip() + + +def _extract_widget_values(pdf_path: str) -> dict[str, str | None]: + pdf = PdfReader(pdf_path) + out: dict[str, str | None] = {} + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + for annot in page.Annots: + if getattr(annot, "Subtype", None) != "/Widget" or not getattr(annot, "T", None): + continue + field_name = str(annot.T).strip("() /") + out[field_name] = _clean_pdf_value(getattr(annot, "V", None)) + return out + + +def _get_pdf_widget_raw_field_names(pdf_path: str) -> list[str]: + """Raw widget names as `Filler.fill_form_by_name` matches them (order: first seen per page walk).""" + pdf = PdfReader(pdf_path) + names: list[str] = [] + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + for annot in page.Annots: + if getattr(annot, 
"Subtype", None) != "/Widget" or not getattr(annot, "T", None): + continue + raw = str(annot.T).strip("() /") + if raw and raw not in names: + names.append(raw) + return names + + +def _widget_names_in_fill_order(pdf_path: str) -> list[str]: + pdf = PdfReader(pdf_path) + order: list[str] = [] + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + sorted_annots = sorted(page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0]))) + for annot in sorted_annots: + if getattr(annot, "Subtype", None) == "/Widget" and getattr(annot, "T", None): + order.append(str(annot.T).strip("() /")) + return order + + +def _widget_values_in_fill_order(pdf_path: str) -> list[str | None]: + """`/V` per widget in the same order as `_widget_names_in_fill_order`.""" + pdf = PdfReader(pdf_path) + out: list[str | None] = [] + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + sorted_annots = sorted(page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0]))) + for annot in sorted_annots: + if getattr(annot, "Subtype", None) == "/Widget" and getattr(annot, "T", None): + out.append(_clean_pdf_value(getattr(annot, "V", None))) + return out + + +def _ensure_fw4_template_pdf() -> str: + assert FW4_PDF.is_file(), f"Missing fixture PDF: {FW4_PDF}" + fm = FileManipulator() + out = fm.create_template(str(FW4_PDF.resolve())) + assert out == str(FW4_TEMPLATE_PDF.resolve()) + return str(FW4_TEMPLATE_PDF.resolve()) + + +@pytest.mark.skipif(not FW4_PDF.is_file(), reason="tests/forms/fw4.pdf required") +def test_extract_template_field_map_matches_unique_widgets(): + template_pdf = _ensure_fw4_template_pdf() + fm = FileManipulator() + m = fm.extract_template_field_map(template_pdf) + assert m == {n: "string" for n in _get_pdf_widget_raw_field_names(template_pdf)} + + +def test_file_manipulator_fill_form_visual_order_matches_val_i(): + """ + Real prepare_form -> fw4_template.pdf; LLM mocked. Dict values follow widget order (val_0, val_1, …). 
+ Filled PDF is written under tests/forms/. + """ + template_pdf = _ensure_fw4_template_pdf() + order = _widget_names_in_fill_order(template_pdf) + n = len(order) + assert n > 0 + + fake_llm_payload = {f"__slot_{i}": f"val_{i}" for i in range(n)} + + with patch("src.file_manipulator.LLM") as MockLLM: + mock_t2j = MagicMock() + mock_t2j.get_data.return_value = fake_llm_payload + MockLLM.return_value.main_loop.return_value = mock_t2j + fm = FileManipulator() + out_path = fm.fill_form( + user_input="synthetic transcript for testing", + fields=order, + pdf_form_path=template_pdf, + ) + + assert out_path and os.path.isfile(out_path) + assert Path(out_path).resolve().parent == FORMS_DIR.resolve() + + filled = _widget_values_in_fill_order(out_path) + assert len(filled) == n + for i in range(n): + assert filled[i] == f"val_{i}", f"position {i} field {order[i]!r}" + + +def test_file_manipulator_fill_report_distributes_val_i_including_shared_canonical(session: Session): + """ + Schema + fw4 template PDF; two PDF fields share one canonical name (same extracted value on both). + LLM mocked with val_0..val_{m-1} for each canonical key (sorted). Assert filled PDF matches distribute(). + """ + template_pdf = _ensure_fw4_template_pdf() + order = _widget_names_in_fill_order(template_pdf) + assert len(order) >= 8 + + # Two distinct widgets share this canonical key — both must receive the same value after fill. 
+ shared_canon = "shared_canon_key" + i_a, i_b = 2, 5 + name_a, name_b = order[i_a], order[i_b] + assert name_a != name_b + + unique_names = _get_pdf_widget_raw_field_names(template_pdf) + fields_map = {w: "string" for w in unique_names} + schema = create_report_schema(session, ReportSchema(name="fw4-schema", description="d", use_case="u")) + tpl = create_template( + session, + Template(name="fw4-tpl", fields=fields_map, pdf_path=template_pdf), + ) + add_template_to_schema(session, schema.id, tpl.id) + + rows_by_name = {f.field_name: f for f in get_schema_fields(session, schema.id)} + update_schema_field( + session, schema.id, rows_by_name[name_a].id, {"canonical_name": shared_canon, "data_type": Datatype.STRING} + ) + update_schema_field( + session, schema.id, rows_by_name[name_b].id, {"canonical_name": shared_canon, "data_type": Datatype.STRING} + ) + + # A couple of non-sequential canonical names (indices disjoint from the shared pair). + extra = ("emp_alias", "tax_pin") + for j, idx in enumerate((6, 7)): + fn = order[idx] + if fn in (name_a, name_b): + continue + update_schema_field( + session, + schema.id, + rows_by_name[fn].id, + {"canonical_name": extra[j], "data_type": Datatype.STRING}, + ) + + canonical_schema = ReportSchemaProcessor.canonize(session, schema.id) + canonical_target = ReportSchemaProcessor.build_extraction_target(canonical_schema) + + sorted_entries = sorted(canonical_schema.canonical_fields, key=lambda e: e.canonical_name) + extracted = {e.canonical_name: f"val_{i}" for i, e in enumerate(sorted_entries)} + + with patch("src.file_manipulator.LLM") as MockLLM: + mock_t2j = MagicMock() + mock_t2j.get_data.return_value = extracted + MockLLM.return_value.main_loop.return_value = mock_t2j + fm = FileManipulator() + outputs = fm.fill_report( + session=session, + user_input="synthetic report transcript", + schema_id=schema.id, + canonical_target=canonical_target, + ) + + assert tpl.id in outputs + out_pdf = outputs[tpl.id] + assert 
os.path.isfile(out_pdf) + assert Path(out_pdf).resolve().parent == FORMS_DIR.resolve() + + distribution = ReportSchemaProcessor.distribute(session, schema.id, extracted) + expected_by_pdf = distribution[tpl.id] + + shared_val = extracted[shared_canon] + assert expected_by_pdf[name_a] == shared_val == expected_by_pdf[name_b] + + values = _extract_widget_values(out_pdf) + for pdf_field, expected in expected_by_pdf.items(): + assert values.get(pdf_field) == str(expected), pdf_field + + +@pytest.mark.skipif(not I9_PDF.is_file(), reason=f"Missing {I9_PDF} (add USCIS I-9 PDF to tests/forms)") +def test_file_manipulator_fill_report_two_pdfs_shared_canonical_cross_file(session: Session): + """ + One schema, fw4 + i-9 templates: a canonical name is assigned to one field on each PDF so both + receive the same extracted value. LLM mocked; assert both outputs under tests/forms/ match distribute(). + Uses prepared *_template.pdf paths and widget names from those files (commonforms output). + """ + assert FW4_PDF.is_file(), f"Missing {FW4_PDF}" + fm0 = FileManipulator() + fm0.create_template(str(FW4_PDF.resolve())) + fm0.create_template(str(I9_PDF.resolve())) + fw4_path = str(FW4_TEMPLATE_PDF.resolve()) + i9_path = str((I9_PDF.parent / (I9_PDF.stem + "_template.pdf")).resolve()) + + fw4_names = _get_pdf_widget_raw_field_names(fw4_path) + i9_names = _get_pdf_widget_raw_field_names(i9_path) + assert len(fw4_names) >= 6 + assert len(i9_names) >= 7 + + cross_canon = "cross_file_shared_value" + fw4_field_a = fw4_names[2] + fw4_field_b = fw4_names[5] + i9_field_a = i9_names[3] + i9_field_b = i9_names[6] + + fields_fw4 = {n: "string" for n in fw4_names} + fields_i9 = {n: "string" for n in i9_names} + + schema = create_report_schema( + session, ReportSchema(name="dual-pdf-schema", description="d", use_case="u") + ) + tpl_fw4 = create_template( + session, + Template(name="fw4-tpl", fields=fields_fw4, pdf_path=fw4_path), + ) + tpl_i9 = create_template( + session, + Template(name="i9-tpl", 
fields=fields_i9, pdf_path=i9_path), + ) + add_template_to_schema(session, schema.id, tpl_fw4.id) + add_template_to_schema(session, schema.id, tpl_i9.id) + + by_tpl_and_name: dict[tuple[int, str], object] = {} + for f in get_schema_fields(session, schema.id): + by_tpl_and_name[(f.source_template_id, f.field_name)] = f + + def _set_canon(tpl_id: int, field_name: str, canon: str) -> None: + sf = by_tpl_and_name[(tpl_id, field_name)] + update_schema_field( + session, + schema.id, + sf.id, + {"canonical_name": canon, "data_type": Datatype.STRING}, + ) + + # Same canonical on two different files -> one extraction value applied to both widgets. + _set_canon(tpl_fw4.id, fw4_field_a, cross_canon) + _set_canon(tpl_i9.id, i9_field_a, cross_canon) + + # Second cross-file pair (different canonical) to stress junction mapping per template. + cross_canon_2 = "cross_file_shared_value_2" + _set_canon(tpl_fw4.id, fw4_field_b, cross_canon_2) + _set_canon(tpl_i9.id, i9_field_b, cross_canon_2) + + canonical_schema = ReportSchemaProcessor.canonize(session, schema.id) + canonical_target = ReportSchemaProcessor.build_extraction_target(canonical_schema) + + sorted_entries = sorted(canonical_schema.canonical_fields, key=lambda e: e.canonical_name) + extracted = {e.canonical_name: f"val_{i}" for i, e in enumerate(sorted_entries)} + + with patch("src.file_manipulator.LLM") as MockLLM: + mock_t2j = MagicMock() + mock_t2j.get_data.return_value = extracted + MockLLM.return_value.main_loop.return_value = mock_t2j + fm = FileManipulator() + outputs = fm.fill_report( + session=session, + user_input="synthetic cross-form transcript", + schema_id=schema.id, + canonical_target=canonical_target, + ) + + assert tpl_fw4.id in outputs and tpl_i9.id in outputs + for tid in (tpl_fw4.id, tpl_i9.id): + p = outputs[tid] + assert os.path.isfile(p) + assert Path(p).resolve().parent == FORMS_DIR.resolve() + + distribution = ReportSchemaProcessor.distribute(session, schema.id, extracted) + + shared = 
extracted[cross_canon] + shared2 = extracted[cross_canon_2] + assert distribution[tpl_fw4.id][fw4_field_a] == shared == distribution[tpl_i9.id][i9_field_a] + assert distribution[tpl_fw4.id][fw4_field_b] == shared2 == distribution[tpl_i9.id][i9_field_b] + + for tid, tpl in (tpl_fw4.id, tpl_fw4), (tpl_i9.id, tpl_i9): + expected_by_pdf = distribution[tid] + values = _extract_widget_values(outputs[tid]) + for pdf_field, expected in expected_by_pdf.items(): + assert values.get(pdf_field) == str(expected), f"{tpl.name} {pdf_field!r}" diff --git a/tests/unit/test_filler.py b/tests/unit/test_filler.py new file mode 100644 index 0000000..cfc0963 --- /dev/null +++ b/tests/unit/test_filler.py @@ -0,0 +1,134 @@ +import os +import sys +from pathlib import Path +from shutil import copyfile +from typing import Any + +import pytest +from pdfrw import PdfReader + +sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) + +from src.filler import Filler + + +BASE_DIR = Path(__file__).resolve().parent.parent +FW4_FIXTURE_PATH = BASE_DIR / "forms" / "fw4.pdf" + + +def _clean_pdf_value(v: Any) -> str | None: + if v is None: + return None + s = str(v).strip() + # pdfrw commonly wraps strings as "(...)". 
+ if s.startswith("(") and s.endswith(")"): + s = s[1:-1] + return s.strip() + + +def _extract_widget_values(pdf_path: str) -> dict[str, str | None]: + pdf = PdfReader(pdf_path) + out: dict[str, str | None] = {} + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + for annot in page.Annots: + if getattr(annot, "Subtype", None) != "/Widget" or not getattr(annot, "T", None): + continue + field_name = str(annot.T).strip("() /") + out[field_name] = _clean_pdf_value(getattr(annot, "V", None)) + return out + + +def _get_widget_names_in_fill_order(pdf_path: str) -> list[str]: + """ + Replicates the fill order in `src.filler.Filler.fill_form`: + - For each page: sort page.Annots by (-Rect.y1, Rect.x0) + - Then fill only widgets with a non-empty annot.T, in that sorted order. + """ + pdf = PdfReader(pdf_path) + order: list[str] = [] + + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + + sorted_annots = sorted(page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0]))) + for annot in sorted_annots: + if getattr(annot, "Subtype", None) == "/Widget" and getattr(annot, "T", None): + order.append(str(annot.T).strip("() /")) + return order + + +def test_fill_form_by_name_updates_matching_widget_values_and_writes_pdf(tmp_path: Path): + filler = Filler() + + assert FW4_FIXTURE_PATH.exists(), f"Missing fixture: {FW4_FIXTURE_PATH}" + input_pdf = FW4_FIXTURE_PATH + + # Pick two widget names from the real PDF. 
+ before_values = _extract_widget_values(str(input_pdf)) + widget_names = [k for k in before_values.keys() if before_values[k] is not None or before_values[k] is None] + assert len(widget_names) >= 2, "Fixture PDF does not have at least 2 widgets" + + name1, name2 = widget_names[0], widget_names[1] + out_value_1 = "UNIT_TEST_VALUE_1" + out_value_2 = "UNIT_TEST_VALUE_2" + + out_pdf_path = filler.fill_form_by_name( + pdf_form=str(input_pdf), + field_values={name1: out_value_1, name2: out_value_2}, + ) + + assert os.path.exists(out_pdf_path), f"Output PDF not created: {out_pdf_path}" + + after_values = _extract_widget_values(out_pdf_path) + assert after_values[name1] == out_value_1 + assert after_values[name2] == out_value_2 + + # Unmatched widgets should not be overwritten to our unit-test values. + unmatched = widget_names[2] + assert after_values[unmatched] not in {out_value_1, out_value_2} + + # Ensure NeedAppearances was forced (for visibility in viewers). + out_pdf = PdfReader(out_pdf_path) + if getattr(out_pdf.Root, "AcroForm", None): + need = getattr(out_pdf.Root.AcroForm, "NeedAppearances", None) + assert need is not None + + +def test_fill_form_assigns_answers_in_visual_order_and_stops_when_exhausted(tmp_path: Path): + filler = Filler() + + assert FW4_FIXTURE_PATH.exists(), f"Missing fixture: {FW4_FIXTURE_PATH}" + input_pdf = FW4_FIXTURE_PATH + + fill_order = _get_widget_names_in_fill_order(str(input_pdf)) + assert len(fill_order) >= 3, "Fixture PDF fill order is unexpectedly short" + + # Give fewer answers than widgets to ensure exhaustion behavior. + answers = ["UNIT_TEST_FILL_0", "UNIT_TEST_FILL_1"] + + class _FakeT2J: + def get_data(self) -> dict[str, str]: + # In `Filler.fill_form`, answers_list = list(textbox_answers.values()). + # Use insertion order to control which widgets get which value. 
+ return {"a": answers[0], "b": answers[1]} + + class _FakeLLM: + def main_loop(self): + return _FakeT2J() + + out_pdf_path = filler.fill_form(pdf_form=str(input_pdf), llm=_FakeLLM()) + assert os.path.exists(out_pdf_path), f"Output PDF not created: {out_pdf_path}" + + values = _extract_widget_values(out_pdf_path) + + # First two widgets in fill order should be set. + assert values[fill_order[0]] == answers[0] + assert values[fill_order[1]] == answers[1] + + # The next widget should not receive a value from `answers`. + for next_widget in fill_order[2:]: + assert values[next_widget] not in set(answers) + diff --git a/tests/unit/test_report_schema_processor.py b/tests/unit/test_report_schema_processor.py new file mode 100644 index 0000000..43154ae --- /dev/null +++ b/tests/unit/test_report_schema_processor.py @@ -0,0 +1,181 @@ +import sys +from pathlib import Path + +import pytest +from sqlmodel import SQLModel, Session, create_engine + +sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) + +from api.db.models import Datatype, ReportSchema, ReportSchemaTemplate, Template +from api.db.repositories import ( + add_template_to_schema, + create_report_schema, + create_template, + get_schema_fields, + update_schema_field, +) +from api.schemas.report_class import CanonicalFieldEntry, CanonicalSchema, SchemaFieldResponse +from src.report_schema import ReportSchemaProcessor + + +test_engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + + +@pytest.fixture(name="session") +def session_fixture(): + SQLModel.metadata.create_all(test_engine) + with Session(test_engine) as session: + yield session + SQLModel.metadata.drop_all(test_engine) + + +def _setup_schema_with_two_templates(session: Session): + schema = create_report_schema(session, ReportSchema(name="s1", description="d1", use_case="u1")) + t1 = create_template(session, Template(name="t1", fields={"f1": "v1", "f2": "v2"}, pdf_path="t1.pdf")) + t2 = create_template(session, 
Template(name="t2", fields={"f3": "v3", "f4": "v4"}, pdf_path="t2.pdf")) + add_template_to_schema(session=session, template_id=t1.id, schema_id=schema.id) + add_template_to_schema(session=session, template_id=t2.id, schema_id=schema.id) + return schema, t1, t2 + + +def test_canonize_raises_for_missing_schema(session: Session): + with pytest.raises(ValueError, match="ReportSchema .* not found"): + ReportSchemaProcessor.canonize(session=session, schema_id=999999) + + +def test_canonize_groups_fields_and_updates_junction_mappings(session: Session): + schema, t1, t2 = _setup_schema_with_two_templates(session) + fields = sorted(get_schema_fields(session=session, schema_id=schema.id), key=lambda f: f.field_name) + assert len(fields) == 4 + + # Merge f1 + f2 into one canonical field. + update_schema_field(session, schema.id, fields[0].id, {"canonical_name": "person_name"}) + update_schema_field(session, schema.id, fields[1].id, {"canonical_name": "person_name"}) + + canonical = ReportSchemaProcessor.canonize(session=session, schema_id=schema.id) + canonical_by_name = {f.canonical_name: f for f in canonical.canonical_fields} + + assert canonical.report_schema_id == schema.id + assert set(canonical_by_name.keys()) == {"person_name", "f3", "f4"} + assert len(canonical_by_name["person_name"].source_fields) == 2 + + t1_junction = session.query(ReportSchemaTemplate).filter( + ReportSchemaTemplate.report_schema_id == schema.id, + ReportSchemaTemplate.template_id == t1.id, + ).one() + t2_junction = session.query(ReportSchemaTemplate).filter( + ReportSchemaTemplate.report_schema_id == schema.id, + ReportSchemaTemplate.template_id == t2.id, + ).one() + assert t1_junction.field_mapping == {"person_name": ["f1", "f2"]} + assert t2_junction.field_mapping == {"f3": "f3", "f4": "f4"} + + +def test_canonize_uses_raw_field_name_when_canonical_name_missing(session: Session): + schema, _, _ = _setup_schema_with_two_templates(session) + canonical = 
ReportSchemaProcessor.canonize(session=session, schema_id=schema.id) + names = {f.canonical_name for f in canonical.canonical_fields} + assert names == {"f1", "f2", "f3", "f4"} + + +def test_build_extraction_target_maps_types_required_enum_and_word_limit(): + canonical_schema = CanonicalSchema( + report_schema_id=1, + canonical_fields=[ + CanonicalFieldEntry( + canonical_name="name", + description="Patient name", + data_type=Datatype.STRING, + word_limit=4, + required=True, + allowed_values=None, + source_fields=[], + ), + CanonicalFieldEntry( + canonical_name="age", + description="Patient age", + data_type=Datatype.INT, + word_limit=None, + required=True, + allowed_values=None, + source_fields=[], + ), + CanonicalFieldEntry( + canonical_name="visit_date", + description="Visit date", + data_type=Datatype.DATE, + word_limit=None, + required=False, + allowed_values=None, + source_fields=[], + ), + CanonicalFieldEntry( + canonical_name="status", + description="Final status", + data_type=Datatype.ENUM, + word_limit=None, + required=False, + allowed_values={"values": ["draft", "final"]}, + source_fields=[], + ), + CanonicalFieldEntry( + canonical_name="enum_without_values", + description="Bad enum metadata", + data_type=Datatype.ENUM, + word_limit=None, + required=False, + allowed_values={"other": [1, 2]}, + source_fields=[], + ), + ], + ) + + target = ReportSchemaProcessor.build_extraction_target(canonical_schema) + + assert target["type"] == "object" + assert set(target["properties"].keys()) == { + "name", + "age", + "visit_date", + "status", + "enum_without_values", + } + assert set(target["required"]) == {"name", "age"} + assert target["properties"]["name"]["type"] == "string" + assert "Maximum 4 words" in target["properties"]["name"]["description"] + assert target["properties"]["age"]["type"] == "integer" + assert target["properties"]["visit_date"]["type"] == "string" + assert target["properties"]["status"]["type"] == "string" + assert 
target["properties"]["status"]["enum"] == ["draft", "final"] + assert "enum" not in target["properties"]["enum_without_values"] + + +def test_distribute_returns_per_template_payload_from_mapping(session: Session): + schema, t1, t2 = _setup_schema_with_two_templates(session) + fields = sorted(get_schema_fields(session=session, schema_id=schema.id), key=lambda f: f.field_name) + + update_schema_field(session, schema.id, fields[0].id, {"canonical_name": "a"}) + update_schema_field(session, schema.id, fields[1].id, {"canonical_name": "b"}) + update_schema_field(session, schema.id, fields[2].id, {"canonical_name": "x"}) + update_schema_field(session, schema.id, fields[3].id, {"canonical_name": "y"}) + ReportSchemaProcessor.canonize(session=session, schema_id=schema.id) + + canonical_data = {"a": "A", "b": "B", "x": "X", "ignored": "IGNORED"} + distribution = ReportSchemaProcessor.distribute(session, schema.id, canonical_data) + + assert distribution[t1.id] == {"f1": "A", "f2": "B"} + assert distribution[t2.id] == {"f3": "X"} + + +def test_distribute_handles_missing_junctions_or_empty_mappings(session: Session): + schema = create_report_schema(session, ReportSchema(name="s-empty", description="d", use_case="u")) + + assert ReportSchemaProcessor.distribute(session, schema.id, {"x": 1}) == {} + + # Create a template + junction, but no canonical mapping update. + t1 = create_template(session, Template(name="t-empty", fields={"f1": "v1"}, pdf_path="t-empty.pdf")) + add_template_to_schema(session, schema_id=schema.id, template_id=t1.id) + + # field_mapping defaults to {}, so distribution should contain empty payload for template. 
+ distribution = ReportSchemaProcessor.distribute(session, schema.id, {"x": 1}) + assert distribution == {t1.id: {}} \ No newline at end of file From 550fb5690237d2236c555163dfe5a1e62c12989d Mon Sep 17 00:00:00 2001 From: Caleb Muthama Date: Tue, 7 Apr 2026 17:22:09 +0300 Subject: [PATCH 3/3] feat: syntactic+semantic validation-driven extraction pipeline --- src/file_manipulator.py | 2 +- src/llm.py | 413 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 390 insertions(+), 25 deletions(-) diff --git a/src/file_manipulator.py b/src/file_manipulator.py index 9382fe7..1886e32 100644 --- a/src/file_manipulator.py +++ b/src/file_manipulator.py @@ -76,7 +76,7 @@ def fill_report(self, session: Session, user_input: str, schema_id: int, canonic try: # 1. Extract against the canonical target - self.llm.set_model_config(provider="gemini", model_name="gemini-2.5-flash") + self.llm.set_model_config() self.llm._target_fields = canonical_target self.llm._transcript_text = user_input diff --git a/src/llm.py b/src/llm.py index 70937f9..9a9af2c 100644 --- a/src/llm.py +++ b/src/llm.py @@ -1,9 +1,16 @@ +from api.schemas.report_class import CanonicalSchema, CanonicalFieldEntry +from api.db.models import Datatype +from typing import Any import json import os import requests +from dotenv import load_dotenv +load_dotenv() class LLM: + model_config = None + def __init__(self, transcript_text=None, target_fields=None, json=None): if json is None: json = {} @@ -11,6 +18,181 @@ def __init__(self, transcript_text=None, target_fields=None, json=None): self._target_fields = target_fields # List, contains the template field. self._json = json # dictionary + if LLM.model_config is None: + LLM.set_model_config() + + @classmethod + def set_model_config(cls, provider: str = None, model_name: str = None): + """ + Configure the model settings for local or online inference globally for the class. + Falls back to environment variables LLM_PROVIDER and LLM_MODEL if not specified. 
+ """ + provider = provider or os.getenv("LLM_PROVIDER", "ollama") + model_name = model_name or os.getenv("LLM_MODEL", "mistral") + + config = { + "provider": provider, + "model": model_name + } + + if provider == "ollama": + host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + config["url"] = f"{host}/api/chat" + elif provider == "gemini": + config["url"] = "https://generativelanguage.googleapis.com/v1beta/models" + config["api_key"] = os.getenv("GEMINI_API_KEY") + else: + raise ValueError(f"Unknown provider: {provider}") + + cls.model_config = config + + @classmethod + def inference(cls, messages: list[dict], format: str = "json") -> str: + """ + Run inference using the globally specified configuration. Returns the raw string response content. + """ + if cls.model_config is None: + cls.set_model_config() + + config = cls.model_config + provider = config.get("provider", "ollama") + model = config.get("model", "mistral") + url = config.get("url") + + if provider == "ollama": + payload = { + "model": model, + "messages": messages, + "stream": False + } + if format == "json": + payload["format"] = "json" + + response = requests.post(url, json=payload) + response.raise_for_status() + return response.json()["message"]["content"] + + elif provider == "gemini": + api_key = config.get("api_key") + if not api_key: + raise ValueError("Gemini API key not found in GEMINI_API_KEY") + + gemini_contents = [] + system_instruction = None + + for msg in messages: + role = msg["role"] + content = msg["content"] + + if role == "system": + system_instruction = content + continue + + gemini_role = "model" if role == "assistant" else "user" + gemini_contents.append({ + "role": gemini_role, + "parts": [{"text": content}] + }) + + payload = {"contents": gemini_contents} + + if system_instruction: + payload["system_instruction"] = { + "parts": [{"text": system_instruction}] + } + + if format == "json": + payload["generationConfig"] = { + "responseMimeType": 
"application/json" + } + + gemini_url = f"{url}/{model}:generateContent?key={api_key}" + response = requests.post(gemini_url, json=payload) + response.raise_for_status() + return response.json()["candidates"][0]["content"]["parts"][0]["text"] + + else: + raise ValueError(f"Unknown provider: {provider}") + + @staticmethod + def syntactic_validator(expected_schema: CanonicalFieldEntry, extracted_value: Any) -> list[dict]: + """ + A validator to validate the syntactic correctness of extracted values against Canonical Field Descriptors + """ + errors = [] + + if expected_schema.data_type == Datatype.INT: + try: + int(extracted_value) + except (ValueError, TypeError): + errors.append({"data_type_error": f"expected: int, however {extracted_value} is: {type(extracted_value).__name__}"}) + elif expected_schema.data_type == Datatype.STRING: + if not isinstance(extracted_value, str): + errors.append({"data_type_error": f"expected: string, however {extracted_value} is: {type(extracted_value).__name__}"}) + elif expected_schema.data_type == Datatype.DATE: + if not isinstance(extracted_value, str): + errors.append({"data_type_error": f"expected: date string, however {extracted_value} is: {type(extracted_value).__name__}"}) + + if expected_schema.word_limit and isinstance(extracted_value, str) and len(extracted_value.split(" ")) > expected_schema.word_limit: + errors.append({"word_limit_error": f"extracted value word count {len(extracted_value.split(' '))} exceeds word limit of {expected_schema.word_limit}"}) + + if expected_schema.allowed_values and "values" in expected_schema.allowed_values: + allowed = expected_schema.allowed_values["values"] + if extracted_value not in allowed: + errors.append({"allowed_values_error": f"extracted value: {extracted_value} not in allowed values: {allowed}"}) + + return errors if errors else None + + @classmethod + def semantic_validator(cls, extraction_batch: list[dict], context: str) -> dict: + """ + A validator to validate the semantic 
correctness of a batch of extracted values against the original context and reasoning. + extraction_batch: list of dicts with: field_name, description, extracted_value, reasoning. + Returns: dict mapping field_name to a string of semantic errors. Empty if valid. + """ + if not extraction_batch: + return {} + + prompt = f""" + You are a semantic validator agent. Your task is to verify if the extracted values make semantic sense + based on the provided source text, field descriptions, and the reasoning used for extraction. + + Batch of Extracted Fields: + {json.dumps(extraction_batch, indent=2)} + + Source Text: + {context} + + Given the source text, evaluate if each extracted value in the batch is correct and semantically sound according to its field description? + Respond in JSON format as a mapping of field_name to validation results. Each validation result must contain "is_valid" (boolean) and "errors" (a string explaining why it is invalid, or empty string if valid). + + Example output: + {{ + "field_name_1": {{ + "is_valid": false, + "errors": "The extracted value mentions a date, but the field description asks for a status." 
+ }}, + "field_name_2": {{ + "is_valid": true, + "errors": "" + }} + }} + """ + + messages = [{"role": "user", "content": prompt}] + + try: + response_content = cls.inference(messages, format="json") + parsed = json.loads(response_content) + + errors_dict = {} + for field_name, result in parsed.items(): + if isinstance(result, dict) and not result.get("is_valid", True): + errors_dict[field_name] = result.get("errors", "Semantic validation failed") + return errors_dict + except Exception as e: + return {item["field_name"]: f"Failed to perform semantic validation: {e}" for item in extraction_batch} + def type_check_all(self): if type(self._transcript_text) is not str: raise TypeError( @@ -48,32 +230,12 @@ def main_loop(self): # self.type_check_all() for field in self._target_fields.keys(): prompt = self.build_prompt(field) - # print(prompt) - # ollama_url = "http://localhost:11434/api/generate" - ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") - ollama_url = f"{ollama_host}/api/generate" - - payload = { - "model": "mistral", - "prompt": prompt, - "stream": False, # don't really know why --> look into this later. - } try: - response = requests.post(ollama_url, json=payload) - response.raise_for_status() - except requests.exceptions.ConnectionError: - raise ConnectionError( - f"Could not connect to Ollama at {ollama_url}. " - "Please ensure Ollama is running and accessible." 
- ) - except requests.exceptions.HTTPError as e: - raise RuntimeError(f"Ollama returned an error: {e}") - - # parse response - json_data = response.json() - parsed_response = json_data["response"] - # print(parsed_response) + parsed_response = self.inference(messages=[{"role": "user", "content": prompt}], format=None) + except Exception as e: + raise RuntimeError(f"Ollama/Inference returned an error: {e}") + self.add_response_to_json(field, parsed_response) print("----------------------------------") @@ -83,6 +245,209 @@ def main_loop(self): return self + def extractor(self, extraction_target: CanonicalSchema): + """ + An extractor agent that extracts target values from the trasncript, and performs basic syntactic validation of the + extracted values, and iteratively redoes extraction for failed fields + """ + MAX_OUTER_RETRIES = 10 + MAX_SYNTACTIC_RETRIES = 5 + + pending_fields = {field.canonical_name: field for field in extraction_target.canonical_fields} + results_dict = {} + semantic_errors = {} + + # Init conversation history array + conversation_history = [ + { + "role": "system", + "content": "You are a report generator agent. You extract objectives from source text into a JSON map with `candidate_value`, `reasoning`, and `confidence`." + } + ] + + outer_iteration = 0 + while pending_fields: + outer_iteration += 1 + + if outer_iteration > MAX_OUTER_RETRIES: + for field_name in list(pending_fields.keys()): + results_dict[field_name] = None + break + + target_fields_info = [] + for field_name, field in pending_fields.items(): + field_info = { + "field_name": field.canonical_name, + "description": field.description, + "expected_data_type": field.data_type, + "word_limit": field.word_limit if field.data_type == "string" else None, + "required": field.required, + "allowed_values": field.allowed_values, + } + target_fields_info.append(field_info) + + prompt = f""" + Extract ONLY the specified target fields from the source text. 
+ Respond in the requested JSON format. + + Target Fields: {target_fields_info} + Source Text: {self._transcript_text} + + JSON Format Expectation: + {{ + "field_name_1": {{ + "candidate_value": , + "reasoning": , + "confidence": + }} + }} + """ + + conversation_history.append({"role": "user", "content": prompt}) + + try: + response_content = self.inference(conversation_history, format="json") + conversation_history.append({"role": "assistant", "content": response_content}) + parsed_response = json.loads(response_content) + except Exception as e: + parsed_response = {} + + syntactically_valid = {} + fields_to_fix = {} + + # Initial syntactic validation + for field_name in list(pending_fields.keys()): + if field_name not in parsed_response: + fields_to_fix[field_name] = { + "extractor_output": None, + "errors": [{"missing_field": "Field was not extracted."}] + } + continue + + list_or_dict = parsed_response[field_name] + if not isinstance(list_or_dict, dict): + fields_to_fix[field_name] = { + "extractor_output": list_or_dict, + "errors": [{"format_error": "Extracted field output must be a dictionary with candidate_value."}] + } + continue + + candidate_value = list_or_dict.get("candidate_value") + errors = self.syntactic_validator(pending_fields[field_name], candidate_value) + + if errors: + fields_to_fix[field_name] = { + "extractor_output": list_or_dict, + "errors": errors + } + else: + syntactically_valid[field_name] = list_or_dict + + # 2. 
Syntactic Correction Loop + syntactic_retry = 0 + while fields_to_fix: + syntactic_retry += 1 + + if syntactic_retry > MAX_SYNTACTIC_RETRIES: + for field_name in list(fields_to_fix.keys()): + syntactically_valid[field_name] = fields_to_fix[field_name].get("extractor_output") or {"candidate_value": None, "reasoning": "max retries", "confidence": 0} + break + + correction_targets = [] + for field_name, fix_ctx in fields_to_fix.items(): + field_info = { + "field_name": field_name, + "expected_data_type": pending_fields[field_name].data_type, + "previous_invalid_output": fix_ctx.get("extractor_output"), + "syntactic_errors": fix_ctx.get("errors") + } + correction_targets.append(field_info) + + correction_prompt = f""" + You previously extracted values that failed syntactic validation. + Please re-extract ONLY the following fields, correcting the syntactic errors indicated. + Return the exact same JSON format. + + Target Fields to Correct: + {correction_targets} + """ + + conversation_history.append({"role": "user", "content": correction_prompt}) + + try: + response_content = self.inference(conversation_history, format="json") + conversation_history.append({"role": "assistant", "content": response_content}) + correction_response = json.loads(response_content) + except Exception as e: + correction_response = {} + + new_fields_to_fix = {} + for field_name, fix_ctx in fields_to_fix.items(): + if field_name in correction_response and isinstance(correction_response[field_name], dict): + candidate_value = correction_response[field_name].get("candidate_value") + errors = self.syntactic_validator(pending_fields[field_name], candidate_value) + if errors: + new_fields_to_fix[field_name] = { + "extractor_output": correction_response[field_name], + "errors": errors + } + else: + syntactically_valid[field_name] = correction_response[field_name] + else: + new_fields_to_fix[field_name] = fix_ctx + + fields_to_fix = new_fields_to_fix + + # 3. 
Semantic Validation Loop (Batch and Threshold Filtering) + batch_to_validate = [] + + for field_name, extracted_dict in syntactically_valid.items(): + if not isinstance(extracted_dict, dict): + extracted_dict = {"candidate_value": None, "confidence": 0} + confidence = extracted_dict.get("confidence", 0) + try: + confidence = float(confidence) + except (ValueError, TypeError): + confidence = 0.0 + + if confidence >= 0.90: + results_dict[field_name] = extracted_dict.get("candidate_value") + del pending_fields[field_name] + else: + batch_to_validate.append({ + "field_name": field_name, + "description": pending_fields[field_name].description, + "extracted_value": extracted_dict.get("candidate_value"), + "reasoning": extracted_dict.get("reasoning", "") + }) + + if batch_to_validate: + batch_errors = self.semantic_validator(batch_to_validate, self._transcript_text) + + failed_semantic_names = [] + for item in batch_to_validate: + f_name = item["field_name"] + if f_name in batch_errors: + semantic_errors[f_name] = batch_errors[f_name] + failed_semantic_names.append(f_name) + else: + results_dict[f_name] = item["extracted_value"] + del pending_fields[f_name] + + if failed_semantic_names: + feedback_msg = "The following fields failed semantic validation. Please correct your reasoning and re-extract them accurately based on the source text:\n" + for f_name in failed_semantic_names: + feedback_msg += f"- '{f_name}': {semantic_errors[f_name]}\n" + conversation_history.append({"role": "user", "content": feedback_msg}) + + # 4. End of iteration. If pending_fields is empty, loop breaks. + + # Store results for existing class dependencies + for field, value in results_dict.items(): + self.add_response_to_json(field, str(value)) + + return results_dict + def add_response_to_json(self, field, value): """ this method adds the following value under the specified field,