Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 46 additions & 23 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@
SYMBOL_REGEX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


class TDMissingMarker(str):
"""
Custom Typed Dict missing marker to indicate values that are in the annotations but not present at runtime.
We are using a custom subclass string type to be able to differentiate them when creating schemas.
See `py_avro_schema._schemas.TypedDictSchema._record_field` and `UnionSchema._validate_union`
"""

...


TD_MISSING_MARKER = TDMissingMarker("__td_missing__")


class TypeNotSupportedError(TypeError):
"""Error raised when a Avro schema cannot be generated for a given Python type"""

Expand Down Expand Up @@ -312,20 +325,14 @@ def _wrap_as_record(self, inner_schema: JSONObj, names: NamesType) -> JSONType:
if fullname in names:
return fullname
names.append(fullname)

fields = [
{"name": REF_ID_KEY, "type": ["null", "long"], "default": None},
{"name": REF_DATA_KEY, "type": inner_schema},
]
if Option.ADD_RUNTIME_TYPE_FIELD in self.options:
fields.append({"name": RUNTIME_TYPE_KEY, "type": ["null", "string"]})

record_schema = {
"type": "record",
"name": record_name,
"fields": fields,
"fields": [
{"name": REF_ID_KEY, "type": ["null", "long"], "default": None},
{"name": REF_DATA_KEY, "type": inner_schema},
],
}

if self.namespace:
record_schema["namespace"] = self.namespace
return record_schema
Expand Down Expand Up @@ -864,8 +871,34 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o
super().__init__(py_type, namespace=namespace, options=options)
py_type = _type_from_annotated(py_type)
args = get_args(py_type)
self._validate_union(args)
self.item_schemas = [_schema_obj(arg, namespace=namespace, options=options) for arg in args]

@staticmethod
def _validate_union(args: tuple[Any, ...]) -> None:
"""
Validate that the arguments of the Union are possible to deal with. At runtime, we cannot get the runtime type
of TypedDict instances, as they are just regular dicts.
Same for sequences like List and Set, we would have to scan them to know all the runtime types of the values
they contain.
:param args: list of types of the Union
:return: None
:raises: TypeError if the Union types are invalid
"""
if type(None) not in args and TDMissingMarker not in args:
if any(
# Enum is treated as a Sequence
not EnumSchema.handles_type(arg)
and (
is_typeddict(arg)
or SequenceSchema.handles_type(arg)
or DictSchema.handles_type(arg)
or SetSchema.handles_type(arg)
)
for arg in args
):
raise TypeError(f"Union of types {args} is not supported. Python cannot detect proper type at runtime")

def data(self, names: NamesType) -> JSONType:
"""Return the schema data"""
# Render the item schemas
Expand Down Expand Up @@ -1302,23 +1335,13 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti
self.py_fields: list[tuple[str, type]] = []
for k, v in type_hints.items():
self.py_fields.append((k, v))
# We store __init__ parameters with default values. They can be used as defaults for the record.
self.signature_fields = {
param.name: (param.annotation, param.default)
for param in list(inspect.signature(py_type.__init__).parameters.values())[1:]
if param.default is not inspect._empty
}
self.record_fields = [self._record_field(field) for field in self.py_fields]

def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
"""Return an Avro record field object for a given Python instance attribute"""
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])
name = py_field[0]
default = dataclasses.MISSING
if field := self.signature_fields.get(name):
_annotation, _default = field
if actual_type is _annotation:
default = _default or dataclasses.MISSING
field_obj = RecordField(
py_type=actual_type,
name=name,
Expand Down Expand Up @@ -1370,15 +1393,15 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
# be able to distinguish between the fields that are missing from the ones that are present but set to None.
# To do that, we extend the original type with str. We will later add a special string
# (e.g., __td_missing__) as a marker at deserialization time.
actual_type = Union[actual_type, str] # type: ignore
actual_type = Union[actual_type, TDMissingMarker] # type: ignore
if _is_optional(actual_type):
# Note: this works since this schema does not implement `make_default` and the base implementation
# simply return the provided type (None in this case).
default = "__td_missing__" # type: ignore
default = TD_MISSING_MARKER # type: ignore
elif _is_not_required(actual_type):
# A field can be marked with typing.NotRequired even in a TypedDict with is not marked with total=False.
# Similarly as above, we extend the wrapped type with string.
actual_type = Union[_unwrap_not_required(actual_type), str] # type: ignore
actual_type = Union[_unwrap_not_required(actual_type), TDMissingMarker] # type: ignore

field_obj = RecordField(
py_type=actual_type,
Expand Down
4 changes: 3 additions & 1 deletion src/py_avro_schema/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import py_avro_schema._schemas


def assert_schema(py_type: Type, expected_schema: Union[str, Dict[str, str], List[str]], **kwargs) -> None:
def assert_schema(
py_type: Type, expected_schema: Union[str, Dict[str, str], List[str | Dict[str, str]]], **kwargs
) -> None:
"""Test that the given Python type results in the correct Avro schema"""
if not kwargs.pop("do_auto_namespace", False):
kwargs["options"] = kwargs.get("options", py_avro_schema.Option(0)) | py_avro_schema.Option.NO_AUTO_NAMESPACE
Expand Down
14 changes: 0 additions & 14 deletions tests/test_avro_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,3 @@ class PyType:
],
}
assert_schema(PyType, expected, options=pas.Option.ADD_RUNTIME_TYPE_FIELD)


def test_add_type_field_on_wrapped_record():
py_type = list[str]
expected = {
"type": "record",
"name": "StrList",
"fields": [
{"name": "__id", "type": ["null", "long"], "default": None},
{"name": "__data", "type": {"type": "array", "items": "string"}},
{"name": "_runtime_type", "type": ["null", "string"]},
],
}
assert_schema(py_type, expected, options=pas.Option.WRAP_INTO_RECORDS | pas.Option.ADD_RUNTIME_TYPE_FIELD)
1 change: 0 additions & 1 deletion tests/test_plain_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(
{
"name": "country",
"type": "string",
"default": "NLD",
},
{
"name": "latitude",
Expand Down
30 changes: 30 additions & 0 deletions tests/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,36 @@ def test_union_str_str():
assert_schema(py_type, expected)


def test_union_str_list_str_error():
py_type = Union[str, list[str]]
with pytest.raises(TypeError):
py_avro_schema._schemas.schema(py_type)


def test_union_str_dict_str_error():
py_type = Union[str, dict[str, str]]
with pytest.raises(TypeError):
py_avro_schema._schemas.schema(py_type)


def test_union_str_set_str_error():
py_type = Union[str, set[str]]
with pytest.raises(TypeError):
py_avro_schema._schemas.schema(py_type)


def test_union_str_tuple_str_error():
py_type = Union[str, tuple[str, ...]]
with pytest.raises(TypeError):
py_avro_schema._schemas.schema(py_type)


def test_union_str_list_str_with_marker():
py_type = Union[list[str], py_avro_schema._schemas.TDMissingMarker]
expected = [{"items": "string", "type": "array"}, {"namedString": "TDMissingMarker", "type": "string"}]
assert_schema(py_type, expected)


def test_union_str_annotated_str():
py_type = Union[str, Annotated[str, ...]]
expected = "string"
Expand Down
62 changes: 48 additions & 14 deletions tests/test_typed_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
# Copyright 2022 J.P. Morgan Chase & Co.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

from enum import StrEnum
from typing import Annotated, NotRequired, TypedDict
from typing import Annotated, NotRequired, TypedDict, Union

import pytest

import py_avro_schema
import py_avro_schema as pas
from py_avro_schema._alias import Alias, register_type_alias
from py_avro_schema._testing import assert_schema
Expand Down Expand Up @@ -104,27 +118,35 @@ class PyType(TypedDict, total=False):
valid: ValidEnumSymbol | None

expected = {
"type": "record",
"name": "PyType",
"fields": [
{"name": "name", "type": "string"},
{"name": "nickname", "type": ["string", "null"], "default": "__td_missing__"},
{"name": "age", "type": ["string", "long", "null"], "default": "__td_missing__"},
{"name": "name", "type": {"namedString": "TDMissingMarker", "type": "string"}},
{
"name": "invalid",
"type": [{"namedString": "InvalidEnumSymbol", "type": "string"}, "null"],
"default": "__td_missing__",
"name": "nickname",
"type": ["null", {"namedString": "TDMissingMarker", "type": "string"}],
},
{
"default": "__td_missing__",
"name": "age",
"type": [{"namedString": "TDMissingMarker", "type": "string"}, "long", "null"],
},
{
"default": "__td_missing__",
"name": "invalid",
"type": [{"namedString": "TDMissingMarker", "type": "string"}, "null"],
},
{
"default": "__td_missing__",
"name": "valid",
"type": [
"string",
{"namedString": "TDMissingMarker", "type": "string"},
{"default": "valid_val", "name": "ValidEnumSymbol", "symbols": ["valid_val"], "type": "enum"},
"null",
],
},
],
"name": "PyType",
"type": "record",
}
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)

Expand All @@ -137,14 +159,14 @@ class PyType(TypedDict):
nullable_value: NotRequired[str | None]

expected = {
"type": "record",
"name": "PyType",
"fields": [
{"name": "name", "type": "string"},
{"name": "value", "type": "string"},
{"name": "value_int", "type": ["long", "string"]},
{"name": "nullable_value", "type": ["string", "null"]},
{"name": "value", "type": {"namedString": "TDMissingMarker", "type": "string"}},
{"name": "value_int", "type": ["long", {"namedString": "TDMissingMarker", "type": "string"}]},
{"name": "nullable_value", "type": ["null", {"namedString": "TDMissingMarker", "type": "string"}]},
],
"name": "PyType",
"type": "record",
}

assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)
Expand All @@ -170,3 +192,15 @@ class PyType(TypedDict):
],
}
assert_schema(PyType, expected, options=pas.Option.ADD_REFERENCE_ID)


def test_union_typed_dict_error():
class PyType(TypedDict):
var: str

class PyType2(TypedDict):
var: str

py_type = Union[PyType, PyType2]
with pytest.raises(TypeError):
py_avro_schema._schemas.schema(py_type)