diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 884c1f9..9f59742 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -87,7 +87,17 @@ class TDMissingMarker(str): ... +class BytesTDMissingMarker(bytes): + """ + Similar to `TDMissingMarker` above, but required for TypedDicts that have non-required `bytes` values, as Avro + is unable to manage Unions of `string` and `bytes`. + """ + + ... + + TD_MISSING_MARKER = TDMissingMarker("__td_missing__") +BYTES_TD_MISSING_MARKER = BytesTDMissingMarker(b"__td_missing__") class TypeNotSupportedError(TypeError): @@ -885,6 +895,9 @@ def _validate_union(args: tuple[Any, ...]) -> None: :return: None :raises: TypeError if the Union types are invalid """ + if str in args and bytes in args: + raise TypeError("Avro does not support Union of types bytes and string") + if type(None) not in args and TDMissingMarker not in args: if any( # Enum is treated as a Sequence @@ -1146,7 +1159,8 @@ def __init__( if self.default != dataclasses.MISSING: if isinstance(self.schema, UnionSchema): self.schema.sort_item_schemas(self.default) - typeguard.check_type("default_value", self.default, self.py_type) + if self.default != TD_MISSING_MARKER: + typeguard.check_type("default_value", self.default, self.py_type) else: if Option.DEFAULTS_MANDATORY in self.options: raise TypeError(f"Default value for field {self} is missing") @@ -1387,21 +1401,25 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: """Return an Avro record field object for a given TypedDict field""" aliases, actual_type = get_field_aliases_and_actual_type(py_field[1]) + # Avro does not handle Unions of bytes and string + marker_type = BytesTDMissingMarker if _is_bytes(actual_type) else TDMissingMarker + default = dataclasses.MISSING if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total: # If a TypedDict is marked as total=False, it does not need to contain all the field. 
However, we need to # be able to distinguish between the fields that are missing from the ones that are present but set to None. # To do that, we extend the original type with str. We will later add a special string # (e.g., __td_missing__) as a marker at deserialization time. - actual_type = Union[actual_type, TDMissingMarker] # type: ignore + actual_type = Union[actual_type, marker_type] # type: ignore if _is_optional(actual_type): # Note: this works since this schema does not implement `make_default` and the base implementation # simply return the provided type (None in this case). + # We need to use the string TD_MISSING_MARKER as the schema cannot serialize bytes default = TD_MISSING_MARKER # type: ignore elif _is_not_required(actual_type): # A field can be marked with typing.NotRequired even in a TypedDict with is not marked with total=False. # Similarly as above, we extend the wrapped type with string. - actual_type = Union[_unwrap_not_required(actual_type), TDMissingMarker] # type: ignore + actual_type = Union[_unwrap_not_required(actual_type), marker_type] # type: ignore field_obj = RecordField( py_type=actual_type, @@ -1433,6 +1451,14 @@ def _is_optional(py_type: Type) -> bool: return False +def _is_bytes(py_type: Type) -> bool: + """Checks if a type is bytes itself or has bytes among its type arguments (e.g. a Union including bytes)""" + try: + return py_type is bytes or bytes in get_args(py_type) + except Exception: + return False + + def _is_not_required(py_type: Type) -> bool: """Checks if a type is marked with typing.NotRequired""" return get_origin(py_type) is NotRequired # noqa diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 5c96b3b..925a11f 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -430,6 +430,15 @@ def test_literal_different_types(): py_avro_schema._schemas.schema(py_type) +def test_union_bytes_string(): + py_type = Union[str, bytes] + with pytest.raises( + TypeError, + match=re.escape("Avro does not support Union of types bytes and 
string"), + ): + py_avro_schema._schemas.schema(py_type) + + def test_optional_str(): py_type = Optional[str] expected = ["string", "null"] diff --git a/tests/test_typed_dict.py b/tests/test_typed_dict.py index 8434108..3ce9b73 100644 --- a/tests/test_typed_dict.py +++ b/tests/test_typed_dict.py @@ -116,6 +116,8 @@ class PyType(TypedDict, total=False): age: int | None invalid: InvalidEnumSymbol | None valid: ValidEnumSymbol | None + bytes_data: bytes + bytes_data_nullable: bytes | None expected = { "fields": [ @@ -144,6 +146,11 @@ class PyType(TypedDict, total=False): "null", ], }, + { + "name": "bytes_data", + "type": "bytes", + }, + {"default": "__td_missing__", "name": "bytes_data_nullable", "type": ["bytes", "null"]}, ], "name": "PyType", "type": "record",