From d1adeeb55fe81888efacd8cf241ef1a1d2ac74e7 Mon Sep 17 00:00:00 2001 From: jdsika Date: Thu, 2 Apr 2026 17:20:24 +0200 Subject: [PATCH 1/5] feat(generators): add --exclude-external-imports flag Add a dedicated --exclude-external-imports / --no-exclude-external-imports CLI flag to control whether external vocabulary terms are included in generated artifacts when --no-mergeimports is set. Previously external terms leaked into JSON-LD contexts even with --no-mergeimports. The new flag explicitly suppresses terms whose class_uri or slot_uri belong to an imported (external) schema. Tests cover linkml:types built-in import preservation, local file import preservation, and interaction with mergeimports=False. Signed-off-by: jdsika --- .../src/linkml/generators/jsonldcontextgen.py | 54 +++- .../test_generators/test_jsonldcontextgen.py | 291 ++++++++++++++++++ 2 files changed, 342 insertions(+), 3 deletions(-) diff --git a/packages/linkml/src/linkml/generators/jsonldcontextgen.py b/packages/linkml/src/linkml/generators/jsonldcontextgen.py index 60eaa9ffd..306783dd8 100644 --- a/packages/linkml/src/linkml/generators/jsonldcontextgen.py +++ b/packages/linkml/src/linkml/generators/jsonldcontextgen.py @@ -56,8 +56,22 @@ class ContextGenerator(Generator): fix_multivalue_containers: bool | None = False exclude_imports: bool = False """If True, elements from imported schemas won't be included in the generated context""" + exclude_external_imports: bool = False + """If True, elements from URL-based external vocabulary imports are excluded. + + Local file imports and linkml standard imports are kept. This is useful + when extending an external ontology (e.g. W3C Verifiable Credentials) + whose terms are ``@protected`` in their own JSON-LD context — redefining + them locally would violate JSON-LD 1.1 §4.1.11. + + Note: this flag has no effect when ``mergeimports=False`` because + non-local elements are already absent from the visitor iteration + in that mode. 
+ """ _local_classes: set | None = field(default=None, repr=False) _local_slots: set | None = field(default=None, repr=False) + _external_classes: set | None = field(default=None, repr=False) + _external_slots: set | None = field(default=None, repr=False) # Framing (opt-in via CLI flag) emit_frame: bool = False @@ -69,7 +83,7 @@ def __post_init__(self) -> None: super().__post_init__() if self.namespaces is None: raise TypeError("Schema text must be supplied to context generator. Preparsed schema will not work") - if self.exclude_imports: + if self.exclude_imports or self.exclude_external_imports: if self.schemaview: sv = self.schemaview else: @@ -77,8 +91,31 @@ def __post_init__(self) -> None: if isinstance(source, str) and self.base_dir and not Path(source).is_absolute(): source = str(Path(self.base_dir) / source) sv = SchemaView(source, importmap=self.importmap, base_dir=self.base_dir) - self._local_classes = set(sv.all_classes(imports=False).keys()) - self._local_slots = set(sv.all_slots(imports=False).keys()) + if self.exclude_imports: + self._local_classes = set(sv.all_classes(imports=False).keys()) + self._local_slots = set(sv.all_slots(imports=False).keys()) + if self.exclude_external_imports: + self._external_classes, self._external_slots = self._collect_external_elements(sv) + + @staticmethod + def _collect_external_elements(sv: SchemaView) -> tuple[set[str], set[str]]: + """Identify classes and slots from URL-based external vocabulary imports. + + Walks the SchemaView ``schema_map`` (populated by ``imports_closure``) + and collects element names from schemas whose import key starts with + ``http://`` or ``https://``. Local file imports and ``linkml:`` + standard imports are left untouched. 
+ """ + sv.imports_closure() + external_classes: set[str] = set() + external_slots: set[str] = set() + for schema_key, schema_def in sv.schema_map.items(): + if schema_key == sv.schema.name: + continue + if schema_key.startswith("http://") or schema_key.startswith("https://"): + external_classes.update(schema_def.classes.keys()) + external_slots.update(schema_def.slots.keys()) + return external_classes, external_slots def visit_schema(self, base: str | Namespace | None = None, output: str | None = None, **_): # Add any explicitly declared prefixes @@ -194,6 +231,8 @@ def end_schema( def visit_class(self, cls: ClassDefinition) -> bool: if self.exclude_imports and cls.name not in self._local_classes: return False + if self.exclude_external_imports and cls.name in self._external_classes: + return False class_def = {} cn = camelcase(cls.name) @@ -246,6 +285,8 @@ def _literal_coercion_for_ranges(self, ranges: list[str]) -> tuple[bool, str | N def visit_slot(self, aliased_slot_name: str, slot: SlotDefinition) -> None: if self.exclude_imports and slot.name not in self._local_slots: return + if self.exclude_external_imports and slot.name in self._external_slots: + return if slot.identifier: slot_def = "@id" @@ -390,6 +431,13 @@ def serialize( help="Use --exclude-imports to exclude imported elements from the generated JSON-LD context. This is useful when " "extending an ontology whose terms already have context definitions in their own JSON-LD context file.", ) +@click.option( + "--exclude-external-imports/--no-exclude-external-imports", + default=False, + show_default=True, + help="Exclude elements from URL-based external vocabulary imports while keeping local file imports. " + "Useful when extending ontologies (e.g. 
W3C VC v2) whose terms are @protected in their own JSON-LD context.", +) @click.version_option(__version__, "-V", "--version") def cli(yamlfile, emit_frame, embed_context_in_frame, output, **args): """Generate jsonld @context definition from LinkML model""" diff --git a/tests/linkml/test_generators/test_jsonldcontextgen.py b/tests/linkml/test_generators/test_jsonldcontextgen.py index 6de23347a..ff5b75e66 100644 --- a/tests/linkml/test_generators/test_jsonldcontextgen.py +++ b/tests/linkml/test_generators/test_jsonldcontextgen.py @@ -1,4 +1,5 @@ import json +import textwrap import pytest from click.testing import CliRunner @@ -571,3 +572,293 @@ def test_exclude_imports(input_path): # Imported class and slot must NOT be present assert "BaseClass" not in ctx, "Imported class 'BaseClass' must not appear in exclude-imports context" assert "baseProperty" not in ctx, "Imported slot 'baseProperty' must not appear in exclude-imports context" + + +@pytest.mark.parametrize("mergeimports", [True, False], ids=["merge", "no-merge"]) +def test_exclude_external_imports(tmp_path, mergeimports): + """With --exclude-external-imports, elements from URL-based external + vocabulary imports must not appear in the generated JSON-LD context, + while local file imports and linkml standard imports are kept. + + When a schema imports terms from an external vocabulary (e.g. W3C VC + v2), those terms already have context definitions in their own JSON-LD + context file. Re-defining them in the local context can conflict with + @protected term definitions from the external context (JSON-LD 1.1 + section 4.1.11). 
+ """ + ext_dir = tmp_path / "ext" + ext_dir.mkdir() + (ext_dir / "external_vocab.yaml").write_text( + textwrap.dedent("""\ + id: https://example.org/external-vocab + name: external_vocab + default_prefix: ext + prefixes: + linkml: https://w3id.org/linkml/ + ext: https://example.org/external-vocab/ + imports: + - linkml:types + slots: + issuer: + slot_uri: ext:issuer + range: string + validFrom: + slot_uri: ext:validFrom + range: date + classes: + ExternalCredential: + class_uri: ext:ExternalCredential + slots: + - issuer + - validFrom + """), + encoding="utf-8", + ) + + (tmp_path / "main.yaml").write_text( + textwrap.dedent("""\ + id: https://example.org/main + name: main + default_prefix: main + prefixes: + linkml: https://w3id.org/linkml/ + main: https://example.org/main/ + ext: https://example.org/external-vocab/ + imports: + - linkml:types + - https://example.org/external-vocab + slots: + localName: + slot_uri: main:localName + range: string + classes: + LocalThing: + class_uri: main:LocalThing + slots: + - localName + """), + encoding="utf-8", + ) + + importmap = {"https://example.org/external-vocab": str(ext_dir / "external_vocab")} + + context_text = ContextGenerator( + str(tmp_path / "main.yaml"), + exclude_external_imports=True, + mergeimports=mergeimports, + importmap=importmap, + base_dir=str(tmp_path), + ).serialize() + context = json.loads(context_text) + ctx = context["@context"] + + # Local terms must be present + assert "localName" in ctx or "local_name" in ctx, ( + f"Local slot missing with mergeimports={mergeimports}, got: {list(ctx.keys())}" + ) + assert "LocalThing" in ctx, f"Local class missing with mergeimports={mergeimports}, got: {list(ctx.keys())}" + + # External vocabulary terms must NOT be present + assert "issuer" not in ctx, f"External slot 'issuer' present with mergeimports={mergeimports}" + assert "validFrom" not in ctx and "valid_from" not in ctx, ( + f"External slot 'validFrom' present with mergeimports={mergeimports}" + ) + assert 
"ExternalCredential" not in ctx, ( + f"External class 'ExternalCredential' present with mergeimports={mergeimports}" + ) + + +def test_exclude_external_imports_preserves_linkml_types(tmp_path): + """linkml:types (standard library import) must NOT be treated as external. + + The ``linkml:types`` import resolves to a URL internally + (``https://w3id.org/linkml/types``), but it is a standard LinkML import, + not a user-declared external vocabulary. The ``_collect_external_elements`` + method filters by ``schema_key.startswith("http")`` — this test verifies + that linkml built-in types (string, integer, date, etc.) survive the filter. + """ + (tmp_path / "schema.yaml").write_text( + textwrap.dedent("""\ + id: https://example.org/test + name: test_linkml_types + default_prefix: ex + prefixes: + linkml: https://w3id.org/linkml/ + ex: https://example.org/ + imports: + - linkml:types + slots: + name: + slot_uri: ex:name + range: string + age: + slot_uri: ex:age + range: integer + classes: + Person: + class_uri: ex:Person + slots: + - name + - age + """), + encoding="utf-8", + ) + + context_text = ContextGenerator( + str(tmp_path / "schema.yaml"), + exclude_external_imports=True, + ).serialize() + ctx = json.loads(context_text)["@context"] + + # Local classes and slots must be present + assert "Person" in ctx, f"Local class 'Person' missing, got: {list(ctx.keys())}" + assert "name" in ctx, f"Local slot 'name' missing, got: {list(ctx.keys())}" + assert "age" in ctx, f"Local slot 'age' missing, got: {list(ctx.keys())}" + + +def test_exclude_external_imports_preserves_local_file_imports(tmp_path): + """Local file imports (non-URL) must be preserved when exclude_external_imports is set. + + Only URL-based imports (http:// or https://) are considered external. + File-path imports between local schemas must remain in the context. 
+ """ + local_dir = tmp_path / "local" + local_dir.mkdir() + (local_dir / "base.yaml").write_text( + textwrap.dedent("""\ + id: https://example.org/base + name: base + default_prefix: base + prefixes: + linkml: https://w3id.org/linkml/ + base: https://example.org/base/ + imports: + - linkml:types + slots: + baseField: + slot_uri: base:baseField + range: string + classes: + BaseRecord: + class_uri: base:BaseRecord + slots: + - baseField + """), + encoding="utf-8", + ) + + (tmp_path / "main.yaml").write_text( + textwrap.dedent("""\ + id: https://example.org/main + name: main + default_prefix: main + prefixes: + linkml: https://w3id.org/linkml/ + main: https://example.org/main/ + base: https://example.org/base/ + imports: + - linkml:types + - local/base + slots: + localField: + slot_uri: main:localField + range: string + classes: + MainRecord: + class_uri: main:MainRecord + slots: + - localField + """), + encoding="utf-8", + ) + + context_text = ContextGenerator( + str(tmp_path / "main.yaml"), + exclude_external_imports=True, + mergeimports=True, + base_dir=str(tmp_path), + ).serialize() + ctx = json.loads(context_text)["@context"] + + # Local file import terms must be present + assert "MainRecord" in ctx, f"Local class 'MainRecord' missing, got: {list(ctx.keys())}" + assert "BaseRecord" in ctx, f"Local-file-imported class 'BaseRecord' missing, got: {list(ctx.keys())}" + assert "baseField" in ctx or "base_field" in ctx, ( + f"Local-file-imported slot 'baseField' missing, got: {list(ctx.keys())}" + ) + + +def test_exclude_external_imports_works_with_mergeimports_false(tmp_path): + """exclude_external_imports is effective even when mergeimports=False. + + Although mergeimports=False prevents most imported elements from appearing, + external vocabulary elements can still leak into the context via the + schema_map. The exclude_external_imports flag catches these. 
+ """ + ext_dir = tmp_path / "ext" + ext_dir.mkdir() + (ext_dir / "external_vocab.yaml").write_text( + textwrap.dedent("""\ + id: https://example.org/external-vocab + name: external_vocab + default_prefix: ext + prefixes: + linkml: https://w3id.org/linkml/ + ext: https://example.org/external-vocab/ + imports: + - linkml:types + slots: + issuer: + slot_uri: ext:issuer + range: string + classes: + ExternalCredential: + class_uri: ext:ExternalCredential + slots: + - issuer + """), + encoding="utf-8", + ) + + (tmp_path / "main.yaml").write_text( + textwrap.dedent("""\ + id: https://example.org/main + name: main + default_prefix: main + prefixes: + linkml: https://w3id.org/linkml/ + main: https://example.org/main/ + ext: https://example.org/external-vocab/ + imports: + - linkml:types + - https://example.org/external-vocab + slots: + localName: + slot_uri: main:localName + range: string + classes: + LocalThing: + class_uri: main:LocalThing + slots: + - localName + """), + encoding="utf-8", + ) + + importmap = {"https://example.org/external-vocab": str(ext_dir / "external_vocab")} + + ctx_text = ContextGenerator( + str(tmp_path / "main.yaml"), + exclude_external_imports=True, + mergeimports=False, + importmap=importmap, + base_dir=str(tmp_path), + ).serialize() + ctx = json.loads(ctx_text)["@context"] + + # Local terms must still be present + assert "LocalThing" in ctx, f"Local class missing, got: {list(ctx.keys())}" + + # External vocabulary terms must be excluded + assert "issuer" not in ctx, "External slot 'issuer' should be excluded with mergeimports=False" + assert "ExternalCredential" not in ctx, "External class should be excluded with mergeimports=False" From 978be6f45f6a09a53e19d496f43dbfc1276dba6d Mon Sep 17 00:00:00 2001 From: jdsika Date: Thu, 2 Apr 2026 22:46:10 +0200 Subject: [PATCH 2/5] docs: correct stale docstring for exclude_external_imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The docstring incorrectly 
stated that the flag has no effect when mergeimports=False. In reality, external vocabulary elements can still leak into the context via the schema_map even in that mode, and the flag actively catches them — as verified by the existing test test_exclude_external_imports_works_with_mergeimports_false. Replace the inaccurate note with a correct statement. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- packages/linkml/src/linkml/generators/jsonldcontextgen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/linkml/src/linkml/generators/jsonldcontextgen.py b/packages/linkml/src/linkml/generators/jsonldcontextgen.py index 306783dd8..5298a602f 100644 --- a/packages/linkml/src/linkml/generators/jsonldcontextgen.py +++ b/packages/linkml/src/linkml/generators/jsonldcontextgen.py @@ -64,9 +64,9 @@ class ContextGenerator(Generator): whose terms are ``@protected`` in their own JSON-LD context — redefining them locally would violate JSON-LD 1.1 §4.1.11. - Note: this flag has no effect when ``mergeimports=False`` because - non-local elements are already absent from the visitor iteration - in that mode. + This flag is effective regardless of the ``mergeimports`` setting: + even with ``mergeimports=False``, external vocabulary elements can + leak into the context via the schema map. 
""" _local_classes: set | None = field(default=None, repr=False) _local_slots: set | None = field(default=None, repr=False) From 15d2146cb27ceef6f7a6d318a75c01160e71bb39 Mon Sep 17 00:00:00 2001 From: Silvano Cirujano Cuesta Date: Thu, 2 Apr 2026 22:46:13 +0200 Subject: [PATCH 3/5] Identifiers cannot be null --- .../test_generators/input/identifier.yaml | 17 +++++++++++++++ .../test_generators/test_jsonschemagen.py | 21 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 tests/linkml/test_generators/input/identifier.yaml diff --git a/tests/linkml/test_generators/input/identifier.yaml b/tests/linkml/test_generators/input/identifier.yaml new file mode 100644 index 000000000..064877087 --- /dev/null +++ b/tests/linkml/test_generators/input/identifier.yaml @@ -0,0 +1,17 @@ +id: https://example.org/nullable-id +name: nullable-id +prefixes: + linkml: https://w3id.org/linkml/ + ni: https://example.org/nullable-id/ +imports: + - linkml:types +default_range: string + +classes: + MyClass: + attributes: + id: + identifier: true + needed: + required: true + name: diff --git a/tests/linkml/test_generators/test_jsonschemagen.py b/tests/linkml/test_generators/test_jsonschemagen.py index 468880423..ca1654771 100644 --- a/tests/linkml/test_generators/test_jsonschemagen.py +++ b/tests/linkml/test_generators/test_jsonschemagen.py @@ -316,6 +316,27 @@ def test_slot_title_from_title_slot(subtests, input_path): external_file_test(subtests, input_path("jsonschema_slot_title_from_title.yaml"), {"title_from": "title"}) +@pytest.mark.xfail(reason="identifier slots incorrectly allow null (#2448)", strict=True) +@pytest.mark.parametrize("not_closed", [True, False]) +def test_slot_identifier_non_nullability(input_path, not_closed): + """ + Identifier slots are not allowed to be "null" + + References: + - https://github.com/linkml/linkml/issues/2448 + """ + schema = input_path("identifier.yaml") + generator = JsonSchemaGenerator(schema, mergeimports=True, 
not_closed=not_closed) + generated = json.loads(generator.serialize()) + key = "id" + for cls in ["MyClass"]: + id = generated["$defs"][cls]["properties"][key] + if "type" in id: + assert "null" not in id["type"], f"{key} does not allow null" + elif "anyOf" in id: + assert {"type": "null"} not in id["anyOf"], f"{key} does not allow null" + + @pytest.mark.parametrize("not_closed", [True, False]) def test_slot_not_required_nullability(input_path, not_closed): """ From 4d45213909faea3162bb47139b9ab283aa030af9 Mon Sep 17 00:00:00 2001 From: Kevin Schaper Date: Thu, 2 Apr 2026 14:39:22 -0700 Subject: [PATCH 4/5] Handle invalid Python class names in pydanticgen using class alias (#2534) --- .../generators/pydanticgen/pydanticgen.py | 71 ++++++++---- .../test_pydanticgen_special_chars.py | 106 +++++++++++++++++- 2 files changed, 156 insertions(+), 21 deletions(-) diff --git a/packages/linkml/src/linkml/generators/pydanticgen/pydanticgen.py b/packages/linkml/src/linkml/generators/pydanticgen/pydanticgen.py index bfc00ea99..0bc98458b 100644 --- a/packages/linkml/src/linkml/generators/pydanticgen/pydanticgen.py +++ b/packages/linkml/src/linkml/generators/pydanticgen/pydanticgen.py @@ -177,6 +177,11 @@ def make_valid_python_identifier(name: str) -> str: return identifier +def _is_valid_python_name(name: str) -> bool: + """Check if a string is a valid Python identifier and not a keyword.""" + return name.isidentifier() and not keyword.iskeyword(name) + + @dataclass class PydanticGenerator(OOCodeGenerator, LifecycleMixin): """ @@ -475,9 +480,10 @@ def generate_class(self, cls: ClassDefinition) -> ClassResult: if cls.union_of: return self._generate_union_class(cls) + class_python_name = self._get_class_python_name(cls.name) pyclass = PydanticClass( - name=camelcase(cls.name), - bases=self.class_bases.get(camelcase(cls.name), PydanticBaseModel.default_name), + name=class_python_name, + bases=self.class_bases.get(class_python_name, PydanticBaseModel.default_name), 
description=cls.description.replace('"', '\\"') if cls.description is not None else None, ) @@ -537,14 +543,14 @@ def _generate_union_class(self, cls: ClassDefinition) -> ClassResult: ) # Get the union types with string quotes to handle forward references - union_types = [f'"{camelcase(union_cls)}"' for union_cls in cls.union_of] + union_types = [f'"{self._get_class_python_name(union_cls)}"' for union_cls in cls.union_of] union_type_str = f"Union[{', '.join(union_types)}]" # Create a type alias instead of a class # Sanitize description for single-line comment (replace newlines with spaces) description = cls.description.replace("\n", " ").strip() if cls.description else None pyclass = PydanticClass( - name=camelcase(cls.name), + name=self._get_class_python_name(cls.name), bases=[], # Empty list for type aliases description=description, is_type_alias=True, @@ -581,7 +587,7 @@ def generate_slot(self, slot: SlotDefinition, cls: ClassDefinition) -> SlotResul del slot_args["alias"] slot_args["description"] = slot.description.replace('"', '\\"') if slot.description is not None else None - predef = self.predefined_slot_values.get(camelcase(cls.name), {}).get(slot.name, None) + predef = self.predefined_slot_values.get(self._get_class_python_name(cls.name), {}).get(slot.name, None) if predef is not None: slot_args["predefined"] = str(predef) @@ -658,21 +664,19 @@ def predefined_slot_values(self) -> dict[str, dict[str, str]]: ifabsent_processor = PydanticIfAbsentProcessor(sv) slot_values = defaultdict(dict) for class_def in sv.all_classes().values(): + class_python_name = self._get_class_python_name(class_def.name) for slot_name in sv.class_slots(class_def.name): slot = sv.induced_slot(slot_name, class_def.name) if slot.designates_type: target_value = get_type_designator_value(sv, slot, class_def) - slot_values[camelcase(class_def.name)][slot.name] = f'"{target_value}"' + slot_values[class_python_name][slot.name] = f'"{target_value}"' if slot.multivalued: - 
slot_values[camelcase(class_def.name)][slot.name] = ( - "[" + slot_values[camelcase(class_def.name)][slot.name] + "]" + slot_values[class_python_name][slot.name] = ( + "[" + slot_values[class_python_name][slot.name] + "]" ) - slot_values[camelcase(class_def.name)][slot.name] = slot_values[camelcase(class_def.name)][ - slot.name - ] elif slot.ifabsent is not None: value = ifabsent_processor.process_slot(slot, class_def) - slot_values[camelcase(class_def.name)][slot.name] = value + slot_values[class_python_name][slot.name] = value self._predefined_slot_values = slot_values @@ -690,19 +694,46 @@ def class_bases(self) -> dict[str, list[str]]: for class_def in sv.all_classes().values(): class_parents = [] if class_def.is_a: - class_parents.append(camelcase(class_def.is_a)) + class_parents.append(self._get_class_python_name(class_def.is_a)) if self.gen_mixin_inheritance and class_def.mixins: - class_parents.extend([camelcase(mixin) for mixin in class_def.mixins]) + class_parents.extend([self._get_class_python_name(mixin) for mixin in class_def.mixins]) if len(class_parents) > 0: # Use the sorted list of classes to order the parent classes, but reversed to match MRO needs class_parents.sort( key=lambda x: self.sorted_class_names.index(x) if x in self.sorted_class_names else -1 ) class_parents.reverse() - parents[camelcase(class_def.name)] = class_parents + parents[self._get_class_python_name(class_def.name)] = class_parents self._class_bases = parents return self._class_bases + def _get_class_python_name(self, class_name: str) -> str: + """ + Get a valid Python class name for a LinkML class. + + Tries ``camelcase(name)`` first. If that is not a valid Python identifier, + falls back to ``camelcase(alias)`` when the class defines one. Raises + :class:`ValueError` if neither yields a valid identifier. 
+ """ + python_name = camelcase(class_name) + if _is_valid_python_name(python_name): + return python_name + + class_def = self.schemaview.get_class(class_name) + if class_def and class_def.alias: + alias_name = camelcase(class_def.alias) + if _is_valid_python_name(alias_name): + return alias_name + raise ValueError( + f"Class '{class_name}' has alias '{class_def.alias}' but " + f"'{alias_name}' is not a valid Python identifier" + ) + + raise ValueError( + f"Class name '{class_name}' (Python: '{python_name}') is not a valid Python identifier. " + "Consider providing a class alias that is a valid Python identifier." + ) + def get_mixin_identifier_range(self, mixin) -> str: sv = self.schemaview id_ranges = list( @@ -738,9 +769,10 @@ def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: len([x for x in sv.class_induced_slots(slot_range) if x.designates_type]) > 0 and len(sv.class_descendants(slot_range)) > 1 ): - return "Union[" + ",".join([camelcase(c) for c in sv.class_descendants(slot_range)]) + "]" + descendants = [self._get_class_python_name(c) for c in sv.class_descendants(slot_range)] + return "Union[" + ",".join(descendants) + "]" else: - return f"{camelcase(slot_range)}" + return f"{self._get_class_python_name(slot_range)}" # For the more difficult cases, set string as the default and attempt to improve it range_cls_identifier_slot_range = "str" @@ -1064,7 +1096,8 @@ def _get_element_import(self, class_name: ElementName) -> Import: schema_name = self.schemaview.element_by_schema_map()[class_name] schema = [s for s in self.schemaview.schema_map.values() if s.name == schema_name][0] module = self.generate_module_import(schema, self.split_context) - return Import(module=module, objects=[ObjectImport(name=camelcase(class_name))], is_schema=True) + python_name = self._get_class_python_name(class_name) + return Import(module=module, objects=[ObjectImport(name=python_name)], is_schema=True) def render(self) -> PydanticModule: """ @@ -1107,7 
+1140,7 @@ def render(self) -> PydanticModule: # just swap in typing.Any instead down below source_classes = [c for c in source_classes if c.class_uri != "linkml:Any"] source_classes = self.before_generate_classes(source_classes, sv) - self.sorted_class_names = [camelcase(c.name) for c in source_classes] + self.sorted_class_names = [self._get_class_python_name(c.name) for c in source_classes] for cls in source_classes: cls = self.before_generate_class(cls, sv) result = self.generate_class(cls) diff --git a/tests/linkml/test_generators/test_pydanticgen_special_chars.py b/tests/linkml/test_generators/test_pydanticgen_special_chars.py index a9dbd3bc8..9f82d109c 100644 --- a/tests/linkml/test_generators/test_pydanticgen_special_chars.py +++ b/tests/linkml/test_generators/test_pydanticgen_special_chars.py @@ -1,11 +1,11 @@ """ -Tests for Pydantic generator handling of special characters in field names +Tests for Pydantic generator handling of special characters in field and class names """ import pytest from linkml.generators.pydanticgen import PydanticGenerator -from linkml.generators.pydanticgen.pydanticgen import make_valid_python_identifier +from linkml.generators.pydanticgen.pydanticgen import _is_valid_python_name, make_valid_python_identifier from linkml.validator import Validator from linkml.validator.plugins import PydanticValidationPlugin @@ -131,3 +131,105 @@ def test_validation_with_python_field_names(): validator = Validator(schema=schema_text, validation_plugins=[PydanticValidationPlugin()]) report = validator.validate(test_data) assert len(report.results) == 0, f"Validation failed: {report.results}" + + +# --- Tests for _is_valid_python_name --- + + +@pytest.mark.parametrize( + "name, expected", + [ + ("person", True), + ("MyClass", True), + ("_private", True), + ("3DModel", False), + ("Per-son", False), + ("Per!son", False), + ("def", False), + ("class", False), + ("in", False), + ], +) +def test_is_valid_python_name(name: str, expected: bool): + assert 
_is_valid_python_name(name) is expected + + +# --- Tests for invalid class names --- + + +@pytest.mark.parametrize("class_name", ["3DModel", "Per-son", "Per!son"]) +def test_invalid_class_name_without_alias_raises(class_name): + """Classes with invalid Python names and no alias should raise ValueError.""" + schema_text = f""" + id: test_schema + name: test_schema + imports: + - linkml:types + default_range: string + + classes: + {class_name}: + attributes: + name: + range: string + """ + with pytest.raises(ValueError, match="not a valid Python identifier"): + PydanticGenerator(schema_text).serialize() + + +@pytest.mark.parametrize( + "class_name, class_alias", + [ + ("3DModel", "ThreeDModel"), + ("Per-son", "Person"), + ("my-class", "MyClass"), + ], +) +def test_class_name_with_alias(class_name, class_alias): + """Classes with invalid Python names but valid aliases should use the alias.""" + schema_text = f""" + id: test_schema + name: test_schema + imports: + - linkml:types + default_range: string + + classes: + {class_name}: + alias: {class_alias} + attributes: + name: + range: string + """ + generator = PydanticGenerator(schema_text) + code = generator.serialize() + # The generated code should compile and use the alias as the class name + compile(code, "test", "exec") + assert f"class {class_alias}" in code + + +@pytest.mark.parametrize( + "class_name, bad_alias", + [ + ("3DModel", "3DAlias"), + ("Per-son", "Per-son-alias"), + ], +) +def test_class_name_with_invalid_alias_raises(class_name, bad_alias): + """Classes with invalid names AND invalid aliases should raise ValueError.""" + schema_text = f""" + id: test_schema + name: test_schema + imports: + - linkml:types + default_range: string + + classes: + {class_name}: + alias: {bad_alias} + attributes: + name: + range: string + """ + with pytest.raises(ValueError, match="not a valid Python identifier"): + PydanticGenerator(schema_text).serialize() From 820b2473d94d43646fc96f4ad5dd42eb86be3bfa Mon Sep 17 00:00:00 
2001 From: Kevin Schaper Date: Thu, 2 Apr 2026 15:00:39 -0700 Subject: [PATCH 5/5] feat(sqla): add SQLAlchemy 2.x declarative code generation --- .../sqlalchemy_declarative_2x_template.py | 94 +++ .../src/linkml/generators/sqlalchemygen.py | 33 +- .../src/linkml/generators/sqltablegen.py | 11 + pyproject.toml | 8 + .../golden/personinfo_sqla_2x.py | 722 ++++++++++++++++++ .../test_generators/test_sqlalchemygen.py | 160 +++- 6 files changed, 1019 insertions(+), 9 deletions(-) create mode 100644 packages/linkml/src/linkml/generators/sqlalchemy/sqlalchemy_declarative_2x_template.py create mode 100644 tests/linkml/test_generators/golden/personinfo_sqla_2x.py diff --git a/packages/linkml/src/linkml/generators/sqlalchemy/sqlalchemy_declarative_2x_template.py b/packages/linkml/src/linkml/generators/sqlalchemy/sqlalchemy_declarative_2x_template.py new file mode 100644 index 000000000..21738d360 --- /dev/null +++ b/packages/linkml/src/linkml/generators/sqlalchemy/sqlalchemy_declarative_2x_template.py @@ -0,0 +1,94 @@ +sqlalchemy_declarative_2x_template_str = """\ +from __future__ import annotations + +from datetime import date, datetime, time +from decimal import Decimal + +from sqlalchemy import ( + Boolean, + Date, + DateTime, + Enum, + Float, + ForeignKey, + Integer, + Numeric, + Text, + Time, +) +from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + + +class Base(DeclarativeBase): + pass + + +metadata = Base.metadata +{% for c in classes %} + + +class {{ classname(c.name) }}({% if c.is_a %}{{ classname(c.is_a) }}{% else %}Base{% endif %}): + \"\"\" +{% if c.description %} + {{ c.description }} +{% else %} + {{ c.alias }} +{% endif %} + \"\"\" + + __tablename__ = "{{ c.name }}" + +{% for s in c.attributes.values() %} +{% set pytype = python_type(s.annotations['sql_type'].value) %} +{% if 'primary_key' in s.annotations %} + {{ s.alias }}: Mapped[{{ pytype }}] = 
mapped_column({{ s.annotations['sql_type'].value }} + {%- if 'foreign_key' in s.annotations %}, ForeignKey("{{ s.annotations['foreign_key'].value }}"){% endif -%} + , primary_key=True + {%- if 'autoincrement' in s.annotations %}, autoincrement=True{% endif -%} + ) +{% elif 'required' in s.annotations %} + {{ s.alias }}: Mapped[{{ pytype }}] = mapped_column({{ s.annotations['sql_type'].value }} + {%- if 'foreign_key' in s.annotations %}, ForeignKey("{{ s.annotations['foreign_key'].value }}"){% endif -%} + ) +{% else %} + {{ s.alias }}: Mapped[{{ pytype }} | None] = mapped_column({{ s.annotations['sql_type'].value }} + {%- if 'foreign_key' in s.annotations %}, ForeignKey("{{ s.annotations['foreign_key'].value }}"){% endif -%} + ) +{% endif %} +{% if 'foreign_key' in s.annotations and 'original_slot' in s.annotations %} + {{ s.annotations['original_slot'].value }}: Mapped[{{ classname(s.range) }} | None] = relationship(foreign_keys=[{{ s.alias }}]) +{% endif %} +{% endfor %} +{% for mapping in backrefs[c.name] %} +{% if mapping.mapping_type == "ManyToMany" %} + + # ManyToMany + {{ mapping.source_slot }}: Mapped[list[{{ classname(mapping.target_class) }}]] = relationship(secondary="{{ mapping.join_class }}") +{% elif mapping.mapping_type == "MultivaluedScalar" %} + + {{ mapping.source_slot }}_rel: Mapped[list[{{ classname(mapping.join_class) }}]] = relationship() + {{ mapping.source_slot }}: AssociationProxy[list[str]] = association_proxy( + "{{ mapping.source_slot }}_rel", + "{{ mapping.target_slot }}", + creator=lambda x_: {{ classname(mapping.join_class) }}({{ mapping.target_slot }}=x_), + ) +{% else %} + + # One-To-Many: {{ mapping }} + {{ mapping.source_slot }}: Mapped[list[{{ classname(mapping.target_class) }}]] = relationship(foreign_keys="[{{ mapping.target_class }}.{{ mapping.target_slot }}]") +{% endif %} +{% endfor %} + + def __repr__(self): + return f"{{ c.name }}( + {%- for s in c.attributes.values() -%} + {{ s.alias }}={self.{{ s.alias }}}, + {%- endfor 
-%} + )" +{% if c.is_a or c.mixins %} + + __mapper_args__ = {"concrete": True} +{% endif %} +{% endfor %} +""" diff --git a/packages/linkml/src/linkml/generators/sqlalchemygen.py b/packages/linkml/src/linkml/generators/sqlalchemygen.py index 8efab64ba..dc77e3c3f 100644 --- a/packages/linkml/src/linkml/generators/sqlalchemygen.py +++ b/packages/linkml/src/linkml/generators/sqlalchemygen.py @@ -7,15 +7,16 @@ from types import ModuleType import click -from jinja2 import Template +from jinja2 import Environment, Template from sqlalchemy import Enum from linkml._version import __version__ from linkml.generators.pydanticgen import PydanticGenerator from linkml.generators.pythongen import PythonGenerator +from linkml.generators.sqlalchemy.sqlalchemy_declarative_2x_template import sqlalchemy_declarative_2x_template_str from linkml.generators.sqlalchemy.sqlalchemy_declarative_template import sqlalchemy_declarative_template_str from linkml.generators.sqlalchemy.sqlalchemy_imperative_template import sqlalchemy_imperative_template_str -from linkml.generators.sqltablegen import SQLTableGenerator +from linkml.generators.sqltablegen import SQL_TYPE_TO_PYTHON_TYPE, SQLTableGenerator from linkml.transformers.relmodel_transformer import ForeignKeyPolicy, RelationalModelTransformer from linkml.utils.generator import Generator, shared_arguments from linkml_runtime.linkml_model import Annotation, ClassDefinition, ClassDefinitionName, SchemaDefinition @@ -29,6 +30,7 @@ class TemplateEnum(Enum): DECLARATIVE = "declarative" IMPERATIVE = "imperative" + DECLARATIVE_2X = "declarative_2x" @dataclass @@ -80,9 +82,14 @@ def generate_sqla( template_str = sqlalchemy_imperative_template_str elif template == TemplateEnum.DECLARATIVE: template_str = sqlalchemy_declarative_template_str + elif template == TemplateEnum.DECLARATIVE_2X: + template_str = sqlalchemy_declarative_2x_template_str else: raise Exception(f"Unknown template type: {template}") - template_obj = Template(template_str) + if template 
== TemplateEnum.DECLARATIVE_2X: + template_obj = Environment(trim_blocks=True, lstrip_blocks=True).from_string(template_str) + else: + template_obj = Template(template_str) if model_path is None: model_path = self.schema.name logger.info(f"Package for dataclasses == {model_path}") @@ -109,6 +116,7 @@ def generate_sqla( no_model_import=no_model_import, is_join_table=lambda c: any(tag for tag in c.annotations.keys() if tag == "linkml:derived_from"), classes=rel_schema_classes_ordered, + python_type=lambda sql_repr: SQL_TYPE_TO_PYTHON_TYPE.get(sql_repr, "str"), ) logger.debug(f"# Generated code:\n{code}") return code @@ -127,7 +135,7 @@ def compile_sqla( """ Generates and compiles SQL Alchemy bindings - - If template is DECLARATIVE, then a single python module with classes is generated + - If template is DECLARATIVE or DECLARATIVE_2X, then a single python module with classes is generated - If template is IMPERATIVE, only mappings are generated - if compile_python_dataclasses is True then a standard datamodel is generated @@ -142,7 +150,7 @@ def compile_sqla( if model_path is None: model_path = self.schema.name - if template == TemplateEnum.DECLARATIVE: + if template in (TemplateEnum.DECLARATIVE, TemplateEnum.DECLARATIVE_2X): sqla_code = self.generate_sqla(model_path=None, no_model_import=True, template=template, **kwargs) return compile_python(sqla_code, package_path=model_path) elif compile_python_dataclasses: @@ -211,16 +219,27 @@ def order_classes_by_hierarchy(sv: SchemaView) -> list[ClassDefinitionName]: show_default=True, help="Emit FK declarations", ) +@click.option( + "--sqla-style", + type=click.Choice(["1", "2"]), + default=None, + help="SQLAlchemy style to generate (1 or 2). Defaults to 1. 
Only applies in declarative mode.", +) @click.version_option(__version__, "-V", "--version") @click.command(name="sqla") -def cli(yamlfile, declarative, generate_classes, pydantic, use_foreign_keys=True, **args): +def cli(yamlfile, declarative, generate_classes, pydantic, use_foreign_keys=True, sqla_style=None, **args): """Generate SQL DDL representation""" + if sqla_style and not declarative: + raise click.UsageError("--sqla-style only applies in declarative mode (remove --no-declarative)") if pydantic: pygen = PydanticGenerator(yamlfile) print(pygen.serialize()) gen = SQLAlchemyGenerator(yamlfile, **args) if declarative: - t = TemplateEnum.DECLARATIVE + if sqla_style == "2": + t = TemplateEnum.DECLARATIVE_2X + else: + t = TemplateEnum.DECLARATIVE else: t = TemplateEnum.IMPERATIVE if use_foreign_keys: diff --git a/packages/linkml/src/linkml/generators/sqltablegen.py b/packages/linkml/src/linkml/generators/sqltablegen.py index e6196b909..cad0ce01d 100644 --- a/packages/linkml/src/linkml/generators/sqltablegen.py +++ b/packages/linkml/src/linkml/generators/sqltablegen.py @@ -63,6 +63,17 @@ class SqlNamingPolicy(Enum): "XSDDate": Date(), } +SQL_TYPE_TO_PYTHON_TYPE: dict[str, str] = { + "Text()": "str", + "Integer()": "int", + "Float()": "float", + "Numeric()": "Decimal", + "Boolean()": "bool", + "Time()": "time", + "DateTime()": "datetime", + "Date()": "date", +} + VARCHAR_REGEX = re.compile(r"VARCHAR2?(\((\d+)\))?") ORACLE_MAX_VARCHAR_LENGTH = 4096 diff --git a/pyproject.toml b/pyproject.toml index 8e9211a9a..fda49b272 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,13 +132,21 @@ select = [ "UP", # pyupgrade ] +[tool.ruff.format] +# Golden files are generated output whose formatting is determined by the code generator +exclude = ["tests/**/golden/*.py"] + [tool.ruff.lint.isort] known-first-party = ["linkml", "linkml_runtime"] [tool.ruff.lint.per-file-ignores] # These templates can have long lines 
"packages/linkml/src/linkml/generators/sqlalchemy/sqlalchemy_declarative_template.py" = ["E501"] +"packages/linkml/src/linkml/generators/sqlalchemy/sqlalchemy_declarative_2x_template.py" = ["E501"] "packages/linkml/src/linkml/generators/sqlalchemy/sqlalchemy_imperative_template.py" = ["E501"] + +# Golden files are generated output — may have unused imports, long lines, unsorted imports +"tests/**/golden/*.py" = ["E501", "F401", "I001"] "packages/linkml/src/linkml/linter/config/datamodel/config.py" = ["E501", "F401", "I001", "UP007", "UP035", "UP045"] # Auto-generated model files use Optional/Union with string forward references diff --git a/tests/linkml/test_generators/golden/personinfo_sqla_2x.py b/tests/linkml/test_generators/golden/personinfo_sqla_2x.py new file mode 100644 index 000000000..8bfd13f18 --- /dev/null +++ b/tests/linkml/test_generators/golden/personinfo_sqla_2x.py @@ -0,0 +1,722 @@ +from __future__ import annotations + +from datetime import date, datetime, time +from decimal import Decimal + +from sqlalchemy import ( + Boolean, + Date, + DateTime, + Enum, + Float, + ForeignKey, + Integer, + Numeric, + Text, + Time, +) +from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + + +class Base(DeclarativeBase): + pass + + +metadata = Base.metadata + + +class NamedThing(Base): + """ + A generic grouping for any identifiable entity + """ + + __tablename__ = "NamedThing" + + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + + def __repr__(self): + return f"NamedThing(id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},)" + + +class HasAliases(Base): + """ + A mixin applied to any class that can have aliases/alternateNames + """ + + __tablename__ = 
"HasAliases" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + + aliases_rel: Mapped[list[HasAliasesAlias]] = relationship() + aliases: AssociationProxy[list[str]] = association_proxy( + "aliases_rel", + "alias", + creator=lambda x_: HasAliasesAlias(alias=x_), + ) + + def __repr__(self): + return f"HasAliases(id={self.id},)" + + +class HasNewsEvents(Base): + """ + None + """ + + __tablename__ = "HasNewsEvents" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + + # ManyToMany + has_news_events: Mapped[list[NewsEvent]] = relationship(secondary="HasNewsEvents_has_news_event") + + def __repr__(self): + return f"HasNewsEvents(id={self.id},)" + + +class Place(Base): + """ + None + """ + + __tablename__ = "Place" + + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + Container_id: Mapped[int | None] = mapped_column(Integer(), ForeignKey("Container.id")) + + aliases_rel: Mapped[list[PlaceAlias]] = relationship() + aliases: AssociationProxy[list[str]] = association_proxy( + "aliases_rel", + "alias", + creator=lambda x_: PlaceAlias(alias=x_), + ) + + def __repr__(self): + return f"Place(id={self.id},name={self.name},depicted_by={self.depicted_by},Container_id={self.Container_id},)" + + +class Address(Base): + """ + None + """ + + __tablename__ = "Address" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + street: Mapped[str | None] = mapped_column(Text()) + city: Mapped[str | None] = mapped_column(Text()) + postal_code: Mapped[str | None] = mapped_column(Text()) + + def __repr__(self): + return f"Address(id={self.id},street={self.street},city={self.city},postal_code={self.postal_code},)" + + +class Event(Base): + """ + None + """ + + __tablename__ = "Event" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + started_at_time: 
Mapped[date | None] = mapped_column(Date()) + ended_at_time: Mapped[date | None] = mapped_column(Date()) + duration: Mapped[float | None] = mapped_column(Float()) + is_current: Mapped[bool | None] = mapped_column(Boolean()) + + def __repr__(self): + return f"Event(id={self.id},started_at_time={self.started_at_time},ended_at_time={self.ended_at_time},duration={self.duration},is_current={self.is_current},)" + + +class IntegerPrimaryKeyObject(Base): + """ + None + """ + + __tablename__ = "IntegerPrimaryKeyObject" + + int_id: Mapped[int] = mapped_column(Integer(), primary_key=True) + + def __repr__(self): + return f"IntegerPrimaryKeyObject(int_id={self.int_id},)" + + +class CodeSystem(Base): + """ + None + """ + + __tablename__ = "code system" + + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + + def __repr__(self): + return f"code system(id={self.id},name={self.name},)" + + +class Relationship(Base): + """ + None + """ + + __tablename__ = "Relationship" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + started_at_time: Mapped[date | None] = mapped_column(Date()) + ended_at_time: Mapped[date | None] = mapped_column(Date()) + related_to: Mapped[str | None] = mapped_column(Text(), ForeignKey("Person.id")) + type: Mapped[str | None] = mapped_column(Text()) + + def __repr__(self): + return f"Relationship(id={self.id},started_at_time={self.started_at_time},ended_at_time={self.ended_at_time},related_to={self.related_to},type={self.type},)" + + +class WithLocation(Base): + """ + None + """ + + __tablename__ = "WithLocation" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + in_location: Mapped[str | None] = mapped_column(Text(), ForeignKey("Place.id")) + + def __repr__(self): + return f"WithLocation(id={self.id},in_location={self.in_location},)" + + +class Container(Base): + """ + None + """ + + __tablename__ = "Container" + + id: Mapped[int] = 
mapped_column(Integer(), primary_key=True, autoincrement=True) + + # One-To-Many: OneToAnyMapping(source_class='Container', source_slot='persons', mapping_type=None, target_class='Person', target_slot='Container_id', join_class=None, uses_join_table=None, multivalued=False) + persons: Mapped[list[Person]] = relationship(foreign_keys="[Person.Container_id]") + + # One-To-Many: OneToAnyMapping(source_class='Container', source_slot='organizations', mapping_type=None, target_class='Organization', target_slot='Container_id', join_class=None, uses_join_table=None, multivalued=False) + organizations: Mapped[list[Organization]] = relationship(foreign_keys="[Organization.Container_id]") + + # One-To-Many: OneToAnyMapping(source_class='Container', source_slot='places', mapping_type=None, target_class='Place', target_slot='Container_id', join_class=None, uses_join_table=None, multivalued=False) + places: Mapped[list[Place]] = relationship(foreign_keys="[Place.Container_id]") + + def __repr__(self): + return f"Container(id={self.id},)" + + +class PersonAlias(Base): + """ + None + """ + + __tablename__ = "Person_alias" + + Person_id: Mapped[str] = mapped_column(Text(), ForeignKey("Person.id"), primary_key=True) + alias: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"Person_alias(Person_id={self.Person_id},alias={self.alias},)" + + +class PersonHasNewsEvent(Base): + """ + None + """ + + __tablename__ = "Person_has_news_event" + + Person_id: Mapped[str] = mapped_column(Text(), ForeignKey("Person.id"), primary_key=True) + has_news_event_id: Mapped[int] = mapped_column(Integer(), ForeignKey("NewsEvent.id"), primary_key=True) + + def __repr__(self): + return f"Person_has_news_event(Person_id={self.Person_id},has_news_event_id={self.has_news_event_id},)" + + +class HasAliasesAlias(Base): + """ + None + """ + + __tablename__ = "HasAliases_alias" + + HasAliases_id: Mapped[int] = mapped_column(Integer(), ForeignKey("HasAliases.id"), 
primary_key=True) + alias: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"HasAliases_alias(HasAliases_id={self.HasAliases_id},alias={self.alias},)" + + +class HasNewsEventsHasNewsEvent(Base): + """ + None + """ + + __tablename__ = "HasNewsEvents_has_news_event" + + HasNewsEvents_id: Mapped[int] = mapped_column(Integer(), ForeignKey("HasNewsEvents.id"), primary_key=True) + has_news_event_id: Mapped[int] = mapped_column(Integer(), ForeignKey("NewsEvent.id"), primary_key=True) + + def __repr__(self): + return f"HasNewsEvents_has_news_event(HasNewsEvents_id={self.HasNewsEvents_id},has_news_event_id={self.has_news_event_id},)" + + +class OrganizationCategories(Base): + """ + None + """ + + __tablename__ = "Organization_categories" + + Organization_id: Mapped[str] = mapped_column(Text(), ForeignKey("Organization.id"), primary_key=True) + categories: Mapped[str] = mapped_column(Enum('non profit', 'for profit', 'offshore', 'charity', 'shell company', 'loose organization', name='OrganizationType'), primary_key=True) + + def __repr__(self): + return f"Organization_categories(Organization_id={self.Organization_id},categories={self.categories},)" + + +class OrganizationAlias(Base): + """ + None + """ + + __tablename__ = "Organization_alias" + + Organization_id: Mapped[str] = mapped_column(Text(), ForeignKey("Organization.id"), primary_key=True) + alias: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"Organization_alias(Organization_id={self.Organization_id},alias={self.alias},)" + + +class OrganizationHasNewsEvent(Base): + """ + None + """ + + __tablename__ = "Organization_has_news_event" + + Organization_id: Mapped[str] = mapped_column(Text(), ForeignKey("Organization.id"), primary_key=True) + has_news_event_id: Mapped[int] = mapped_column(Integer(), ForeignKey("NewsEvent.id"), primary_key=True) + + def __repr__(self): + return 
f"Organization_has_news_event(Organization_id={self.Organization_id},has_news_event_id={self.has_news_event_id},)" + + +class PlaceAlias(Base): + """ + None + """ + + __tablename__ = "Place_alias" + + Place_id: Mapped[str] = mapped_column(Text(), ForeignKey("Place.id"), primary_key=True) + alias: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"Place_alias(Place_id={self.Place_id},alias={self.alias},)" + + +class ConceptMappings(Base): + """ + None + """ + + __tablename__ = "Concept_mappings" + + Concept_id: Mapped[str] = mapped_column(Text(), ForeignKey("Concept.id"), primary_key=True) + mappings: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"Concept_mappings(Concept_id={self.Concept_id},mappings={self.mappings},)" + + +class DiagnosisConceptMappings(Base): + """ + None + """ + + __tablename__ = "DiagnosisConcept_mappings" + + DiagnosisConcept_id: Mapped[str] = mapped_column(Text(), ForeignKey("DiagnosisConcept.id"), primary_key=True) + mappings: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"DiagnosisConcept_mappings(DiagnosisConcept_id={self.DiagnosisConcept_id},mappings={self.mappings},)" + + +class ProcedureConceptMappings(Base): + """ + None + """ + + __tablename__ = "ProcedureConcept_mappings" + + ProcedureConcept_id: Mapped[str] = mapped_column(Text(), ForeignKey("ProcedureConcept.id"), primary_key=True) + mappings: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"ProcedureConcept_mappings(ProcedureConcept_id={self.ProcedureConcept_id},mappings={self.mappings},)" + + +class OperationProcedureConceptMappings(Base): + """ + None + """ + + __tablename__ = "OperationProcedureConcept_mappings" + + OperationProcedureConcept_id: Mapped[str] = mapped_column(Text(), ForeignKey("OperationProcedureConcept.id"), primary_key=True) + mappings: Mapped[str] = mapped_column(Text(), primary_key=True) + + def 
__repr__(self): + return f"OperationProcedureConcept_mappings(OperationProcedureConcept_id={self.OperationProcedureConcept_id},mappings={self.mappings},)" + + +class ImagingProcedureConceptMappings(Base): + """ + None + """ + + __tablename__ = "ImagingProcedureConcept_mappings" + + ImagingProcedureConcept_id: Mapped[str] = mapped_column(Text(), ForeignKey("ImagingProcedureConcept.id"), primary_key=True) + mappings: Mapped[str] = mapped_column(Text(), primary_key=True) + + def __repr__(self): + return f"ImagingProcedureConcept_mappings(ImagingProcedureConcept_id={self.ImagingProcedureConcept_id},mappings={self.mappings},)" + + +class Person(NamedThing): + """ + A person (alive, dead, undead, or fictional). + """ + + __tablename__ = "Person" + + primary_email: Mapped[str | None] = mapped_column(Text()) + birth_date: Mapped[str | None] = mapped_column(Text()) + age: Mapped[int | None] = mapped_column(Integer()) + gender: Mapped[str | None] = mapped_column(Enum('nonbinary man', 'nonbinary woman', 'transgender woman', 'transgender man', 'cisgender man', 'cisgender woman', name='GenderType')) + telephone: Mapped[str | None] = mapped_column(Text()) + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + Container_id: Mapped[int | None] = mapped_column(Integer(), ForeignKey("Container.id")) + current_address_id: Mapped[int | None] = mapped_column(Integer(), ForeignKey("Address.id")) + current_address: Mapped[Address | None] = relationship(foreign_keys=[current_address_id]) + + # One-To-Many: OneToAnyMapping(source_class='Person', source_slot='has_employment_history', mapping_type=None, target_class='EmploymentEvent', target_slot='Person_id', join_class=None, uses_join_table=None, multivalued=False) + has_employment_history: Mapped[list[EmploymentEvent]] = 
relationship(foreign_keys="[EmploymentEvent.Person_id]") + + # One-To-Many: OneToAnyMapping(source_class='Person', source_slot='has_familial_relationships', mapping_type=None, target_class='FamilialRelationship', target_slot='Person_id', join_class=None, uses_join_table=None, multivalued=False) + has_familial_relationships: Mapped[list[FamilialRelationship]] = relationship(foreign_keys="[FamilialRelationship.Person_id]") + + # One-To-Many: OneToAnyMapping(source_class='Person', source_slot='has_interpersonal_relationships', mapping_type=None, target_class='InterPersonalRelationship', target_slot='Person_id', join_class=None, uses_join_table=None, multivalued=False) + has_interpersonal_relationships: Mapped[list[InterPersonalRelationship]] = relationship(foreign_keys="[InterPersonalRelationship.Person_id]") + + # One-To-Many: OneToAnyMapping(source_class='Person', source_slot='has_medical_history', mapping_type=None, target_class='MedicalEvent', target_slot='Person_id', join_class=None, uses_join_table=None, multivalued=False) + has_medical_history: Mapped[list[MedicalEvent]] = relationship(foreign_keys="[MedicalEvent.Person_id]") + + aliases_rel: Mapped[list[PersonAlias]] = relationship() + aliases: AssociationProxy[list[str]] = association_proxy( + "aliases_rel", + "alias", + creator=lambda x_: PersonAlias(alias=x_), + ) + + # ManyToMany + has_news_events: Mapped[list[NewsEvent]] = relationship(secondary="Person_has_news_event") + + def __repr__(self): + return f"Person(primary_email={self.primary_email},birth_date={self.birth_date},age={self.age},gender={self.gender},telephone={self.telephone},id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},Container_id={self.Container_id},current_address_id={self.current_address_id},)" + + __mapper_args__ = {"concrete": True} + + +class NewsEvent(Event): + """ + None + """ + + __tablename__ = "NewsEvent" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, 
autoincrement=True) + headline: Mapped[str | None] = mapped_column(Text()) + started_at_time: Mapped[date | None] = mapped_column(Date()) + ended_at_time: Mapped[date | None] = mapped_column(Date()) + duration: Mapped[float | None] = mapped_column(Float()) + is_current: Mapped[bool | None] = mapped_column(Boolean()) + + def __repr__(self): + return f"NewsEvent(id={self.id},headline={self.headline},started_at_time={self.started_at_time},ended_at_time={self.ended_at_time},duration={self.duration},is_current={self.is_current},)" + + __mapper_args__ = {"concrete": True} + + +class Organization(NamedThing): + """ + An organization such as a company or university + """ + + __tablename__ = "Organization" + + mission_statement: Mapped[str | None] = mapped_column(Text()) + founding_date: Mapped[str | None] = mapped_column(Text()) + founding_location: Mapped[str | None] = mapped_column(Text(), ForeignKey("Place.id")) + score: Mapped[Decimal | None] = mapped_column(Numeric()) + min_salary: Mapped[str | None] = mapped_column(Text()) + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + Container_id: Mapped[int | None] = mapped_column(Integer(), ForeignKey("Container.id")) + + categories_rel: Mapped[list[OrganizationCategories]] = relationship() + categories: AssociationProxy[list[str]] = association_proxy( + "categories_rel", + "categories", + creator=lambda x_: OrganizationCategories(categories=x_), + ) + + aliases_rel: Mapped[list[OrganizationAlias]] = relationship() + aliases: AssociationProxy[list[str]] = association_proxy( + "aliases_rel", + "alias", + creator=lambda x_: OrganizationAlias(alias=x_), + ) + + # ManyToMany + has_news_events: Mapped[list[NewsEvent]] = relationship(secondary="Organization_has_news_event") + + def __repr__(self): + return 
f"Organization(mission_statement={self.mission_statement},founding_date={self.founding_date},founding_location={self.founding_location},score={self.score},min_salary={self.min_salary},id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},Container_id={self.Container_id},)" + + __mapper_args__ = {"concrete": True} + + +class Concept(NamedThing): + """ + None + """ + + __tablename__ = "Concept" + + code_system: Mapped[str | None] = mapped_column(Text(), ForeignKey("code system.id")) + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + + mappings_rel: Mapped[list[ConceptMappings]] = relationship() + mappings: AssociationProxy[list[str]] = association_proxy( + "mappings_rel", + "mappings", + creator=lambda x_: ConceptMappings(mappings=x_), + ) + + def __repr__(self): + return f"Concept(code_system={self.code_system},id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},)" + + __mapper_args__ = {"concrete": True} + + +class FamilialRelationship(Relationship): + """ + None + """ + + __tablename__ = "FamilialRelationship" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + started_at_time: Mapped[date | None] = mapped_column(Date()) + ended_at_time: Mapped[date | None] = mapped_column(Date()) + related_to: Mapped[str | None] = mapped_column(Text(), ForeignKey("Person.id")) + type: Mapped[str] = mapped_column(Enum('SIBLING_OF', 'PARENT_OF', 'CHILD_OF', name='FamilialRelationshipType')) + Person_id: Mapped[str | None] = mapped_column(Text(), ForeignKey("Person.id")) + + def __repr__(self): + return f"FamilialRelationship(id={self.id},started_at_time={self.started_at_time},ended_at_time={self.ended_at_time},related_to={self.related_to},type={self.type},Person_id={self.Person_id},)" + + 
__mapper_args__ = {"concrete": True} + + +class InterPersonalRelationship(Relationship): + """ + None + """ + + __tablename__ = "InterPersonalRelationship" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + started_at_time: Mapped[date | None] = mapped_column(Date()) + ended_at_time: Mapped[date | None] = mapped_column(Date()) + related_to: Mapped[str | None] = mapped_column(Text(), ForeignKey("Person.id")) + type: Mapped[str] = mapped_column(Text()) + Person_id: Mapped[str | None] = mapped_column(Text(), ForeignKey("Person.id")) + + def __repr__(self): + return f"InterPersonalRelationship(id={self.id},started_at_time={self.started_at_time},ended_at_time={self.ended_at_time},related_to={self.related_to},type={self.type},Person_id={self.Person_id},)" + + __mapper_args__ = {"concrete": True} + + +class EmploymentEvent(Event): + """ + None + """ + + __tablename__ = "EmploymentEvent" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + employed_at: Mapped[str | None] = mapped_column(Text(), ForeignKey("Organization.id")) + salary: Mapped[str | None] = mapped_column(Text()) + started_at_time: Mapped[date | None] = mapped_column(Date()) + ended_at_time: Mapped[date | None] = mapped_column(Date()) + duration: Mapped[float | None] = mapped_column(Float()) + is_current: Mapped[bool | None] = mapped_column(Boolean()) + Person_id: Mapped[str | None] = mapped_column(Text(), ForeignKey("Person.id")) + + def __repr__(self): + return f"EmploymentEvent(id={self.id},employed_at={self.employed_at},salary={self.salary},started_at_time={self.started_at_time},ended_at_time={self.ended_at_time},duration={self.duration},is_current={self.is_current},Person_id={self.Person_id},)" + + __mapper_args__ = {"concrete": True} + + +class MedicalEvent(Event): + """ + None + """ + + __tablename__ = "MedicalEvent" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True, autoincrement=True) + in_location: Mapped[str | None] 
= mapped_column(Text(), ForeignKey("Place.id")) + started_at_time: Mapped[date | None] = mapped_column(Date()) + ended_at_time: Mapped[date | None] = mapped_column(Date()) + duration: Mapped[float | None] = mapped_column(Float()) + is_current: Mapped[bool | None] = mapped_column(Boolean()) + Person_id: Mapped[str | None] = mapped_column(Text(), ForeignKey("Person.id")) + diagnosis_id: Mapped[str | None] = mapped_column(Text(), ForeignKey("DiagnosisConcept.id")) + diagnosis: Mapped[DiagnosisConcept | None] = relationship(foreign_keys=[diagnosis_id]) + procedure_id: Mapped[str | None] = mapped_column(Text(), ForeignKey("ProcedureConcept.id")) + procedure: Mapped[ProcedureConcept | None] = relationship(foreign_keys=[procedure_id]) + + def __repr__(self): + return f"MedicalEvent(id={self.id},in_location={self.in_location},started_at_time={self.started_at_time},ended_at_time={self.ended_at_time},duration={self.duration},is_current={self.is_current},Person_id={self.Person_id},diagnosis_id={self.diagnosis_id},procedure_id={self.procedure_id},)" + + __mapper_args__ = {"concrete": True} + + +class DiagnosisConcept(Concept): + """ + None + """ + + __tablename__ = "DiagnosisConcept" + + code_system: Mapped[str | None] = mapped_column(Text(), ForeignKey("code system.id")) + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + + mappings_rel: Mapped[list[DiagnosisConceptMappings]] = relationship() + mappings: AssociationProxy[list[str]] = association_proxy( + "mappings_rel", + "mappings", + creator=lambda x_: DiagnosisConceptMappings(mappings=x_), + ) + + def __repr__(self): + return f"DiagnosisConcept(code_system={self.code_system},id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},)" + + __mapper_args__ = {"concrete": True} + + +class ProcedureConcept(Concept): + """ + 
None + """ + + __tablename__ = "ProcedureConcept" + + code_system: Mapped[str | None] = mapped_column(Text(), ForeignKey("code system.id")) + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + + mappings_rel: Mapped[list[ProcedureConceptMappings]] = relationship() + mappings: AssociationProxy[list[str]] = association_proxy( + "mappings_rel", + "mappings", + creator=lambda x_: ProcedureConceptMappings(mappings=x_), + ) + + def __repr__(self): + return f"ProcedureConcept(code_system={self.code_system},id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},)" + + __mapper_args__ = {"concrete": True} + + +class OperationProcedureConcept(ProcedureConcept): + """ + None + """ + + __tablename__ = "OperationProcedureConcept" + + code_system: Mapped[str | None] = mapped_column(Text(), ForeignKey("code system.id")) + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + + mappings_rel: Mapped[list[OperationProcedureConceptMappings]] = relationship() + mappings: AssociationProxy[list[str]] = association_proxy( + "mappings_rel", + "mappings", + creator=lambda x_: OperationProcedureConceptMappings(mappings=x_), + ) + + def __repr__(self): + return f"OperationProcedureConcept(code_system={self.code_system},id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},)" + + __mapper_args__ = {"concrete": True} + + +class ImagingProcedureConcept(ProcedureConcept): + """ + None + """ + + __tablename__ = "ImagingProcedureConcept" + + code_system: Mapped[str | None] = mapped_column(Text(), ForeignKey("code system.id")) + id: Mapped[str] = mapped_column(Text(), primary_key=True) + name: 
Mapped[str] = mapped_column(Text()) + description: Mapped[str | None] = mapped_column(Text()) + depicted_by: Mapped[str | None] = mapped_column(Text()) + + mappings_rel: Mapped[list[ImagingProcedureConceptMappings]] = relationship() + mappings: AssociationProxy[list[str]] = association_proxy( + "mappings_rel", + "mappings", + creator=lambda x_: ImagingProcedureConceptMappings(mappings=x_), + ) + + def __repr__(self): + return f"ImagingProcedureConcept(code_system={self.code_system},id={self.id},name={self.name},description={self.description},depicted_by={self.depicted_by},)" + + __mapper_args__ = {"concrete": True} diff --git a/tests/linkml/test_generators/test_sqlalchemygen.py b/tests/linkml/test_generators/test_sqlalchemygen.py index 8efd22766..070cc5269 100644 --- a/tests/linkml/test_generators/test_sqlalchemygen.py +++ b/tests/linkml/test_generators/test_sqlalchemygen.py @@ -1,13 +1,15 @@ import logging import re from collections import Counter +from pathlib import Path import pytest +from click.testing import CliRunner from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from linkml.generators.sqlalchemygen import SQLAlchemyGenerator, TemplateEnum -from linkml.generators.sqltablegen import SQLTableGenerator +from linkml.generators.sqlalchemygen import SQLAlchemyGenerator, TemplateEnum, cli +from linkml.generators.sqltablegen import RANGEMAP, SQL_TYPE_TO_PYTHON_TYPE, SQLTableGenerator from linkml.utils.schema_builder import SchemaBuilder from linkml_runtime.linkml_model import SlotDefinition @@ -21,6 +23,11 @@ def schema(input_path): return str(input_path("personinfo.yaml")) +@pytest.fixture +def schema_path(input_path): + return str(input_path("personinfo.yaml")) + + def test_sqla_basic_imperative(schema): """ Test generation of DDL for imperative mode @@ -317,3 +324,152 @@ def test_sqla_declarative_exec(schema): session.commit() session.close() engine.dispose() + + +# --- 2.x declarative tests --- + + +def 
# --- 2.x declarative tests --- 


def test_sql_type_to_python_type_covers_rangemap():
    """Each SQL type that RANGEMAP can produce must map back to a Python type."""
    for mapped_sql_type in RANGEMAP.values():
        assert repr(mapped_sql_type) in SQL_TYPE_TO_PYTHON_TYPE, f"Missing mapping for {repr(mapped_sql_type)}"


def test_sqla_2x_basic_declarative(schema):
    """Generated 2.x declarative code should use modern patterns and cover all classes."""
    code = SQLAlchemyGenerator(schema).generate_sqla(template=TemplateEnum.DECLARATIVE_2X)

    # Modern 2.x idioms must be present...
    for required_marker in ("class Base(DeclarativeBase):", "Mapped[", "mapped_column("):
        assert required_marker in code

    # ...and legacy 1.x idioms absent.
    for legacy_marker in ("declarative_base()", "= Column("):
        assert legacy_marker not in code

    # Every class of the personinfo schema should be emitted.
    for class_name in [
        "NamedThing",
        "Person",
        "Organization",
        "Place",
        "Address",
        "Event",
        "Concept",
        "DiagnosisConcept",
        "ProcedureConcept",
        "Relationship",
        "FamilialRelationship",
        "EmploymentEvent",
        "MedicalEvent",
    ]:
        assert f"class {class_name}(" in code


def test_sqla_2x_declarative_exec(schema):
    """Full integration test: generate 2.x declarative, create DB, insert, query."""
    engine = create_engine("sqlite://")
    ddl = SQLTableGenerator(schema).generate_ddl()
    # Bootstrap the schema via the raw DBAPI cursor (executescript is
    # sqlite-specific and not exposed through the SQLAlchemy Connection).
    with engine.connect() as connection:
        connection.connection.cursor().executescript(ddl)

    session = sessionmaker(bind=engine)()
    mod = SQLAlchemyGenerator(schema).compile_sqla(template=TemplateEnum.DECLARATIVE_2X)

    # One event wired up by raw FK value, one by ORM object.
    session.add(mod.DiagnosisConcept(id="C999", name="rash"))
    event_by_fk = mod.MedicalEvent(duration=100.0, diagnosis_id="C999")
    cough = mod.DiagnosisConcept(id="C001", name="cough")
    event_by_obj = mod.MedicalEvent(duration=200.0, diagnosis=cough)

    person = mod.Person(id="P1", name="a b", age=22, has_medical_history=[event_by_fk, event_by_obj])
    person.aliases = ["Anne"]
    person.aliases.append("Fred")
    person.has_familial_relationships.append(mod.FamilialRelationship(related_to="P2", type="SIBLING_OF"))
    person.current_address = mod.Address(street="1 a street", city="big city", postal_code="ZZ1 ZZ2")
    session.add(person)
    session.add(mod.Person(id="P2", name="Ferdinand Giggleheim", aliases=["Fred"]))
    session.commit()

    matches = session.query(mod.Person).where(mod.Person.name == person.name).all()
    assert len(matches) == 1
    fetched = matches[0]
    assert fetched.age == 22
    assert Counter(fetched.aliases) == Counter(["Anne", "Fred"])
    assert len(fetched.has_medical_history) == 2
    assert any(event.diagnosis_id == "C999" for event in fetched.has_medical_history)
    assert any(event.diagnosis.name == "cough" for event in fetched.has_medical_history)
    assert any(rel.related_to == "P2" for rel in fetched.has_familial_relationships)

    session.commit()
    session.close()
    engine.dispose()


def test_2x_mixin():
    """Mixins and abstract parents should compose correctly in the 2.x template."""
    builder = SchemaBuilder()
    builder.add_slot(SlotDefinition("ref_to_c1", range="my_class1", multivalued=True))
    builder.add_class("my_mixin", slots=["my_mixin_slot"], mixin=True)
    builder.add_class("my_abstract", slots=["my_abstract_slot"], abstract=True)
    builder.add_class("my_class1", is_a="my_abstract", mixins=["my_mixin"])
    builder.add_class("my_class2", slots=["ref_to_c1"])
    mod = SQLAlchemyGenerator(builder.schema).compile_sqla(template=TemplateEnum.DECLARATIVE_2X)
    child = mod.MyClass1(my_mixin_slot="v1", my_abstract_slot="v2")
    holder = mod.MyClass2(ref_to_c1=[child])
    assert holder.ref_to_c1[0] == child


def test_sqla_2x_enum_columns(schema):
    """Enum-backed columns should use str type annotation and Enum() column type."""
    gen = SQLAlchemyGenerator(schema)
    generated = gen.generate_sqla(template=TemplateEnum.DECLARATIVE_2X)
    # FamilialRelationship.type is enum-backed — should appear with Enum().
    assert "Enum(" in generated
    # The module should also still compile and expose the class.
    compiled = gen.compile_sqla(template=TemplateEnum.DECLARATIVE_2X)
    assert hasattr(compiled, "FamilialRelationship")
def test_sqla_2x_cli(schema_path):
    """Smoke test for --sqla-style 2 CLI flag."""
    runner = CliRunner()
    result = runner.invoke(cli, [schema_path, "--sqla-style", "2"])
    assert result.exit_code == 0
    # The emitted code must be in 2.x declarative style.
    assert "class Base(DeclarativeBase):" in result.output
    assert "Mapped[" in result.output


def test_sqla_style_without_declarative_fails(schema_path):
    """--sqla-style with --no-declarative should error."""
    runner = CliRunner()
    result = runner.invoke(cli, [schema_path, "--no-declarative", "--sqla-style", "2"])
    assert result.exit_code != 0
    assert "only applies in declarative mode" in result.output


# Golden file pinning the exact 2.x generator output for the personinfo schema.
GOLDEN_FILE = Path(__file__).parent / "golden" / "personinfo_sqla_2x.py"


def test_sqla_2x_golden_file(schema):
    """Generated 2.x output should match the checked-in golden file.

    On the first run (golden file absent) the file is written and the test
    is skipped so the new output can be reviewed and committed.
    """
    gen = SQLAlchemyGenerator(schema)
    code = gen.generate_sqla(template=TemplateEnum.DECLARATIVE_2X)
    if not GOLDEN_FILE.exists():
        GOLDEN_FILE.parent.mkdir(parents=True, exist_ok=True)
        # Pin the encoding: Path.write_text/read_text otherwise use the
        # locale default (platform-dependent, e.g. cp1252 on Windows), which
        # would make the byte-equality check below flaky for any non-ASCII
        # generator output (PEP 597).
        GOLDEN_FILE.write_text(code, encoding="utf-8")
        pytest.skip("Golden file created — review and commit it, then re-run.")
    expected = GOLDEN_FILE.read_text(encoding="utf-8")
    assert code == expected, (
        f"Generated output differs from golden file {GOLDEN_FILE}. "
        "If the change is intentional, delete the golden file and re-run to regenerate."
    )