From 8190aa442b0a6a3d5f3c45fd18bfd801d18f3c4b Mon Sep 17 00:00:00 2001 From: Benjamin Simon Date: Fri, 27 Feb 2026 16:33:51 +0100 Subject: [PATCH 1/2] disallow Union of Typed Dict, dict/sequences --- src/py_avro_schema/_schemas.py | 55 +++++++++++++++++++++++------- src/py_avro_schema/_testing.py | 4 ++- tests/test_plain_class.py | 1 - tests/test_primitives.py | 30 ++++++++++++++++ tests/test_typed_dict.py | 62 ++++++++++++++++++++++++++-------- 5 files changed, 123 insertions(+), 29 deletions(-) diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 7eb04a9..268a1b0 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -77,6 +77,19 @@ SYMBOL_REGEX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") +class TDMissingMarker(str): + """ + Custom Typed Dict missing marker to indicate values that are in the annotations but not present at runtime. + We are using a custom subclass string type to be able to differentiate them when creating schemas. + See `py_avro_schema._schemas.TypedDictSchema._record_field` and `UnionSchema._validate_union` + """ + + ... + + +TD_MISSING_MARKER = TDMissingMarker("__td_missing__") + + class TypeNotSupportedError(TypeError): """Error raised when a Avro schema cannot be generated for a given Python type""" @@ -864,8 +877,34 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o super().__init__(py_type, namespace=namespace, options=options) py_type = _type_from_annotated(py_type) args = get_args(py_type) + self._validate_union(args) self.item_schemas = [_schema_obj(arg, namespace=namespace, options=options) for arg in args] + @staticmethod + def _validate_union(args: tuple[Any, ...]) -> None: + """ + Validate that the arguments of the Union are possible to deal with. At runtime, we cannot get the runtime type + of TypedDict instances, as they are just regular dicts. + Same for sequences like List and Set, we would have to scan them to know all the runtime types of the values + they contain. + :param args: list of types of the Union + :return: None + :raises: TypeError if the Union types are invalid + """ + if type(None) not in args and TDMissingMarker not in args: + if any( + # Enum is treated as a Sequence + not EnumSchema.handles_type(arg) + and ( + is_typeddict(arg) + or SequenceSchema.handles_type(arg) + or DictSchema.handles_type(arg) + or SetSchema.handles_type(arg) + ) + for arg in args + ): + raise TypeError(f"Union of types {args} is not supported. Python cannot detect proper type at runtime") + def data(self, names: NamesType) -> JSONType: """Return the schema data""" # Render the item schemas @@ -1302,12 +1341,6 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti self.py_fields: list[tuple[str, type]] = [] for k, v in type_hints.items(): self.py_fields.append((k, v)) - # We store __init__ parameters with default values. They can be used as defaults for the record. - self.signature_fields = { - param.name: (param.annotation, param.default) - for param in list(inspect.signature(py_type.__init__).parameters.values())[1:] - if param.default is not inspect._empty - } self.record_fields = [self._record_field(field) for field in self.py_fields] def _record_field(self, py_field: tuple[str, Type]) -> RecordField: @@ -1315,10 +1348,6 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: aliases, actual_type = get_field_aliases_and_actual_type(py_field[1]) name = py_field[0] default = dataclasses.MISSING - if field := self.signature_fields.get(name): - _annotation, _default = field - if actual_type is _annotation: - default = _default or dataclasses.MISSING field_obj = RecordField( py_type=actual_type, name=name, @@ -1370,15 +1399,15 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: # be able to distinguish between the fields that are missing from the ones that are present but set to None. # To do that, we extend the original type with str. We will later add a special string # (e.g., __td_missing__) as a marker at deserialization time. - actual_type = Union[actual_type, str] # type: ignore + actual_type = Union[actual_type, TDMissingMarker] # type: ignore if _is_optional(actual_type): # Note: this works since this schema does not implement `make_default` and the base implementation # simply return the provided type (None in this case). - default = "__td_missing__" # type: ignore + default = TD_MISSING_MARKER # type: ignore elif _is_not_required(actual_type): # A field can be marked with typing.NotRequired even in a TypedDict with is not marked with total=False. # Similarly as above, we extend the wrapped type with string. - actual_type = Union[_unwrap_not_required(actual_type), str] # type: ignore + actual_type = Union[_unwrap_not_required(actual_type), TDMissingMarker] # type: ignore field_obj = RecordField( py_type=actual_type, diff --git a/src/py_avro_schema/_testing.py b/src/py_avro_schema/_testing.py index 45fe4a2..e223c64 100644 --- a/src/py_avro_schema/_testing.py +++ b/src/py_avro_schema/_testing.py @@ -24,7 +24,9 @@ import py_avro_schema._schemas -def assert_schema(py_type: Type, expected_schema: Union[str, Dict[str, str], List[str]], **kwargs) -> None: +def assert_schema( + py_type: Type, expected_schema: Union[str, Dict[str, str], List[str | Dict[str, str]]], **kwargs +) -> None: """Test that the given Python type results in the correct Avro schema""" if not kwargs.pop("do_auto_namespace", False): kwargs["options"] = kwargs.get("options", py_avro_schema.Option(0)) | py_avro_schema.Option.NO_AUTO_NAMESPACE diff --git a/tests/test_plain_class.py b/tests/test_plain_class.py index 7203236..b814547 100644 --- a/tests/test_plain_class.py +++ b/tests/test_plain_class.py @@ -51,7 +51,6 @@ def __init__( { "name": "country", "type": "string", - "default": "NLD", }, { "name": "latitude", diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 0493a67..5c96b3b 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -385,6 +385,36 @@ def test_union_str_str(): assert_schema(py_type, expected) +def test_union_str_list_str_error(): + py_type = Union[str, list[str]] + with pytest.raises(TypeError): + py_avro_schema._schemas.schema(py_type) + + +def test_union_str_dict_str_error(): + py_type = Union[str, dict[str, str]] + with pytest.raises(TypeError): + py_avro_schema._schemas.schema(py_type) + + +def test_union_str_set_str_error(): + py_type = Union[str, set[str]] + with pytest.raises(TypeError): + py_avro_schema._schemas.schema(py_type) + + +def test_union_str_tuple_str_error(): + py_type = Union[str, tuple[str, ...]] + with pytest.raises(TypeError): + py_avro_schema._schemas.schema(py_type) + + +def test_union_str_list_str_with_marker(): + py_type = Union[list[str], py_avro_schema._schemas.TDMissingMarker] + expected = [{"items": "string", "type": "array"}, {"namedString": "TDMissingMarker", "type": "string"}] + assert_schema(py_type, expected) + + def test_union_str_annotated_str(): py_type = Union[str, Annotated[str, ...]] expected = "string" diff --git a/tests/test_typed_dict.py b/tests/test_typed_dict.py index d63ca71..8434108 100644 --- a/tests/test_typed_dict.py +++ b/tests/test_typed_dict.py @@ -1,6 +1,20 @@ +# Copyright 2022 J.P. Morgan Chase & Co. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + from enum import StrEnum -from typing import Annotated, NotRequired, TypedDict +from typing import Annotated, NotRequired, TypedDict, Union + +import pytest +import py_avro_schema import py_avro_schema as pas from py_avro_schema._alias import Alias, register_type_alias from py_avro_schema._testing import assert_schema @@ -104,27 +118,35 @@ class PyType(TypedDict, total=False): valid: ValidEnumSymbol | None expected = { - "type": "record", - "name": "PyType", "fields": [ - {"name": "name", "type": "string"}, - {"name": "nickname", "type": ["string", "null"], "default": "__td_missing__"}, - {"name": "age", "type": ["string", "long", "null"], "default": "__td_missing__"}, + {"name": "name", "type": {"namedString": "TDMissingMarker", "type": "string"}}, { - "name": "invalid", - "type": [{"namedString": "InvalidEnumSymbol", "type": "string"}, "null"], "default": "__td_missing__", + "name": "nickname", + "type": ["null", {"namedString": "TDMissingMarker", "type": "string"}], + }, + { + "default": "__td_missing__", + "name": "age", + "type": [{"namedString": "TDMissingMarker", "type": "string"}, "long", "null"], + }, + { + "default": "__td_missing__", + "name": "invalid", + "type": [{"namedString": "TDMissingMarker", "type": "string"}, "null"], }, { "default": "__td_missing__", "name": "valid", "type": [ - "string", + {"namedString": "TDMissingMarker", "type": "string"}, {"default": "valid_val", "name": "ValidEnumSymbol", "symbols": ["valid_val"], "type": "enum"}, "null", ], }, ], + "name": "PyType", + "type": "record", } assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS) @@ -137,14 +159,14 @@ class PyType(TypedDict): nullable_value: NotRequired[str | None] expected = { - "type": "record", - "name": "PyType", "fields": [ {"name": "name", "type": "string"}, - {"name": "value", "type": "string"}, - {"name": "value_int", "type": ["long", "string"]}, - {"name": "nullable_value", "type": ["string", "null"]}, + {"name": "value", "type": {"namedString": "TDMissingMarker", "type": "string"}}, + {"name": "value_int", "type": ["long", {"namedString": "TDMissingMarker", "type": "string"}]}, + {"name": "nullable_value", "type": ["null", {"namedString": "TDMissingMarker", "type": "string"}]}, ], + "name": "PyType", + "type": "record", } assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS) @@ -170,3 +192,15 @@ class PyType(TypedDict): ], } assert_schema(PyType, expected, options=pas.Option.ADD_REFERENCE_ID) + + +def test_union_typed_dict_error(): + class PyType(TypedDict): + var: str + + class PyType2(TypedDict): + var: str + + py_type = Union[PyType, PyType2] + with pytest.raises(TypeError): + py_avro_schema._schemas.schema(py_type) From 09f294b3e3fa34edf8e13a758fb0bc168f00935a Mon Sep 17 00:00:00 2001 From: Benjamin Simon Date: Fri, 27 Feb 2026 16:35:23 +0100 Subject: [PATCH 2/2] Revert "add runtime type for wrapped records (#23)" This reverts commit 469d815a59c0d3c32c3062760f5219f9fae73605. --- src/py_avro_schema/_schemas.py | 14 ++++---------- tests/test_avro_schema.py | 14 -------------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 268a1b0..884c1f9 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -325,20 +325,14 @@ def _wrap_as_record(self, inner_schema: JSONObj, names: NamesType) -> JSONType: if fullname in names: return fullname names.append(fullname) - - fields = [ - {"name": REF_ID_KEY, "type": ["null", "long"], "default": None}, - {"name": REF_DATA_KEY, "type": inner_schema}, - ] - if Option.ADD_RUNTIME_TYPE_FIELD in self.options: - fields.append({"name": RUNTIME_TYPE_KEY, "type": ["null", "string"]}) - record_schema = { "type": "record", "name": record_name, - "fields": fields, + "fields": [ + {"name": REF_ID_KEY, "type": ["null", "long"], "default": None}, + {"name": REF_DATA_KEY, "type": inner_schema}, + ], } - if self.namespace: record_schema["namespace"] = self.namespace return record_schema diff --git a/tests/test_avro_schema.py b/tests/test_avro_schema.py index 691e990..5bffbcb 100644 --- a/tests/test_avro_schema.py +++ b/tests/test_avro_schema.py @@ -80,17 +80,3 @@ class PyType: ], } assert_schema(PyType, expected, options=pas.Option.ADD_RUNTIME_TYPE_FIELD) - - -def test_add_type_field_on_wrapped_record(): - py_type = list[str] - expected = { - "type": "record", - "name": "StrList", - "fields": [ - {"name": "__id", "type": ["null", "long"], "default": None}, - {"name": "__data", "type": {"type": "array", "items": "string"}}, - {"name": "_runtime_type", "type": ["null", "string"]}, - ], - } - assert_schema(py_type, expected, options=pas.Option.WRAP_INTO_RECORDS | pas.Option.ADD_RUNTIME_TYPE_FIELD)