Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,17 @@ class TDMissingMarker(str):
...


class BytesTDMissingMarker(bytes):
    """
    Bytes-based counterpart of `TDMissingMarker`.

    Needed for TypedDict fields whose non-required value is of type `bytes`: Avro cannot handle a Union that contains
    both `string` and `bytes`, so such fields are marked with a `bytes` subclass rather than a `str` subclass.
    """


# Sentinel instance of `TDMissingMarker`, carrying the text "__td_missing__"
TD_MISSING_MARKER = TDMissingMarker("__td_missing__")
# Bytes counterpart of the sentinel above: the same payload wrapped in `BytesTDMissingMarker`
BYTES_TD_MISSING_MARKER = BytesTDMissingMarker(b"__td_missing__")


class TypeNotSupportedError(TypeError):
Expand Down Expand Up @@ -885,6 +895,9 @@ def _validate_union(args: tuple[Any, ...]) -> None:
:return: None
:raises: TypeError if the Union types are invalid
"""
if str in args and bytes in args:
raise TypeError("Avro does not support Union of types bytes and string")

if type(None) not in args and TDMissingMarker not in args:
if any(
# Enum is treated as a Sequence
Expand Down Expand Up @@ -1146,7 +1159,8 @@ def __init__(
if self.default != dataclasses.MISSING:
if isinstance(self.schema, UnionSchema):
self.schema.sort_item_schemas(self.default)
typeguard.check_type("default_value", self.default, self.py_type)
if self.default != TD_MISSING_MARKER:
typeguard.check_type("default_value", self.default, self.py_type)
else:
if Option.DEFAULTS_MANDATORY in self.options:
raise TypeError(f"Default value for field {self} is missing")
Expand Down Expand Up @@ -1387,21 +1401,25 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
"""Return an Avro record field object for a given TypedDict field"""
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])

# Avro does not handle Unions of bytes and string
marker_type = BytesTDMissingMarker if _is_bytes(actual_type) else TDMissingMarker

default = dataclasses.MISSING
if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total:
# If a TypedDict is marked as total=False, it does not need to contain all the fields. However, we need to
# be able to distinguish between the fields that are missing and the ones that are present but set to None.
# To do that, we extend the original type with a marker type (a `str` subclass, or a `bytes` subclass when
# the field type is or contains `bytes`). We will later add a special string (e.g., __td_missing__) as a
# marker at deserialization time.
actual_type = Union[actual_type, TDMissingMarker] # type: ignore
actual_type = Union[actual_type, marker_type] # type: ignore
if _is_optional(actual_type):
# Note: this works since this schema does not implement `make_default` and the base implementation
# simply returns the provided type (None in this case).
# We need to use the string TD_MISSING_MARKER as the schema cannot serialize bytes
default = TD_MISSING_MARKER # type: ignore
elif _is_not_required(actual_type):
# A field can be marked with typing.NotRequired even in a TypedDict which is not marked with total=False.
# As above, we extend the wrapped type with the marker type.
actual_type = Union[_unwrap_not_required(actual_type), TDMissingMarker] # type: ignore
actual_type = Union[_unwrap_not_required(actual_type), marker_type] # type: ignore

field_obj = RecordField(
py_type=actual_type,
Expand Down Expand Up @@ -1433,6 +1451,14 @@ def _is_optional(py_type: Type) -> bool:
return False


def _is_bytes(py_type: Type) -> bool:
"""Given a Union of types, checks if bytes is one of those"""
try:
return py_type is bytes or bytes in get_args(py_type)
except Exception:
return False


def _is_not_required(py_type: Type) -> bool:
"""Checks if a type is marked with typing.NotRequired"""
return get_origin(py_type) is NotRequired # noqa
Expand Down
9 changes: 9 additions & 0 deletions tests/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,15 @@ def test_literal_different_types():
py_avro_schema._schemas.schema(py_type)


def test_union_bytes_string():
    """A Union mixing ``str`` and ``bytes`` must be rejected with a TypeError"""
    expected_message = "Avro does not support Union of types bytes and string"
    with pytest.raises(TypeError, match=re.escape(expected_message)):
        py_avro_schema._schemas.schema(Union[str, bytes])


def test_optional_str():
py_type = Optional[str]
expected = ["string", "null"]
Expand Down
7 changes: 7 additions & 0 deletions tests/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class PyType(TypedDict, total=False):
age: int | None
invalid: InvalidEnumSymbol | None
valid: ValidEnumSymbol | None
bytes_data: bytes
bytes_data_nullable: bytes | None

expected = {
"fields": [
Expand Down Expand Up @@ -144,6 +146,11 @@ class PyType(TypedDict, total=False):
"null",
],
},
{
"name": "bytes_data",
"type": "bytes",
},
{"default": "__td_missing__", "name": "bytes_data_nullable", "type": ["bytes", "null"]},
],
"name": "PyType",
"type": "record",
Expand Down
Loading