From 488659a3cfdacca9fd56b09586cddf17af59c927 Mon Sep 17 00:00:00 2001 From: Lukas Bindreiter Date: Fri, 22 Aug 2025 14:41:47 +0200 Subject: [PATCH] Correctly deserialize nested dataclasses in task args in case they are optional --- CHANGELOG.md | 2 + tilebox-workflows/tests/test_task.py | 53 +++++++++++++++++++++ tilebox-workflows/tilebox/workflows/task.py | 35 +++++++++++++- 3 files changed, 89 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbd9d85..a9ab098 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `tilebox-workflows`: Registering duplicate task identifiers with a task runner now raises a `ValueError` instead of overwriting the existing task. +- `tilebox-workflows`: Fixed a bug where the `deserialize_task` function would fail to deserialize nested dataclasses or + protobuf messages that are wrapped in an `Optional` or `Annotated` type hint. ## [0.41.0] - 2025-08-01 diff --git a/tilebox-workflows/tests/test_task.py b/tilebox-workflows/tests/test_task.py index 1a77d75..9d63255 100644 --- a/tilebox-workflows/tests/test_task.py +++ b/tilebox-workflows/tests/test_task.py @@ -1,5 +1,6 @@ import json from dataclasses import dataclass +from typing import Annotated import pytest @@ -9,6 +10,7 @@ ExecutionContext, Task, TaskMeta, + _get_deserialization_field_type, deserialize_task, serialize_task, ) @@ -348,3 +350,54 @@ def test_serialize_deserialize_task_nested_protobuf_in_nested_dict() -> None: {"a": {"b": [SampleArgs(some_string="World", some_int=123), SampleArgs(some_string="!", some_int=456)]}}, ) assert deserialize_task(ExampleTaskWithNestedProtobufInNestedDict, serialize_task(task)) == task + + +class ExampleTaskWithOptionalNestedJson(Task): + x: str + optional_args: NestedJson | None = None + + +def test_serialize_deserialize_task_nested_optional_json() -> None: + task = ExampleTaskWithOptionalNestedJson("Hello") + assert 
deserialize_task(ExampleTaskWithOptionalNestedJson, serialize_task(task)) == task + + task = ExampleTaskWithOptionalNestedJson("Hello", NestedJson(nested_x="World")) + assert deserialize_task(ExampleTaskWithOptionalNestedJson, serialize_task(task)) == task + + +class ExampleTaskWithOptionalNestedProtobuf(Task): + x: str + optional_args: SampleArgs | None = None + + +def test_serialize_deserialize_task_nested_optional_protobuf() -> None: + task = ExampleTaskWithOptionalNestedProtobuf("Hello") + assert deserialize_task(ExampleTaskWithOptionalNestedProtobuf, serialize_task(task)) == task + + task = ExampleTaskWithOptionalNestedProtobuf("Hello", SampleArgs(some_string="World", some_int=123)) + assert deserialize_task(ExampleTaskWithOptionalNestedProtobuf, serialize_task(task)) == task + + +class FieldTypesTest(Task): + field1: str + field2: str | None + field3: NestedJson | None + field4: NestedJson | None + field5: Annotated[NestedJson, "some description"] + field6: Annotated[NestedJson, "some description"] | None + field7: Annotated[NestedJson | None, "some description"] + field8: Annotated[NestedJson | None, "some description"] + field9: Annotated[list[NestedJson] | None, "some description"] + + +def test_get_deserialization_field_type() -> None: + fields = FieldTypesTest.__dataclass_fields__ + assert _get_deserialization_field_type(fields["field1"].type) is str + assert _get_deserialization_field_type(fields["field2"].type) is str + assert _get_deserialization_field_type(fields["field3"].type) is NestedJson + assert _get_deserialization_field_type(fields["field4"].type) is NestedJson + assert _get_deserialization_field_type(fields["field5"].type) is NestedJson + assert _get_deserialization_field_type(fields["field6"].type) is NestedJson + assert _get_deserialization_field_type(fields["field7"].type) is NestedJson + assert _get_deserialization_field_type(fields["field8"].type) is NestedJson + assert _get_deserialization_field_type(fields["field9"].type) == 
list[NestedJson] diff --git a/tilebox-workflows/tilebox/workflows/task.py b/tilebox-workflows/tilebox/workflows/task.py index 0a17181..d0bc30f 100644 --- a/tilebox-workflows/tilebox/workflows/task.py +++ b/tilebox-workflows/tilebox/workflows/task.py @@ -1,10 +1,12 @@ import contextlib import inspect import json +import typing from abc import ABC, ABCMeta, abstractmethod from base64 import b64decode, b64encode from collections.abc import Sequence from dataclasses import dataclass, field, fields, is_dataclass +from types import NoneType, UnionType from typing import Any, cast, get_args, get_origin # from python 3.11 onwards this is available as typing.dataclass_transform: @@ -350,7 +352,7 @@ def deserialize_task(task_cls: type, task_input: bytes) -> Task: return task_cls() # empty task if len(task_fields) == 1: # if there is only one field, we deserialize it directly - field_type = task_fields[0].type + field_type = _get_deserialization_field_type(task_fields[0].type) # type: ignore[arg-type] if hasattr(field_type, "FromString"): # protobuf message value = field_type.FromString(task_input) # type: ignore[arg-type] else: @@ -372,6 +374,10 @@ def _deserialize_dataclass(cls: type, params: dict[str, Any]) -> Task: def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911 + if value is None: + return None + + field_type = _get_deserialization_field_type(field_type) if hasattr(field_type, "FromString"): return field_type.FromString(b64decode(value)) if is_dataclass(field_type) and isinstance(value, dict): @@ -398,3 +404,30 @@ def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911 return {k: _deserialize_value(type_args[1], v) for k, v in value.items()} return value + + +def _get_deserialization_field_type(field_type: type) -> type: + """ + Get the actual underlying type we want to deserialize a field type annotated as. + + This correctly handles optional and annotated type hints. 
+ + For example, all of the following fields should be deserialized as the MyDataclass class: + + field1: MyDataclass + field2: MyDataclass | None + field3: Optional[MyDataclass] + field4: Annotated[MyDataclass, "some description"] + field5: Annotated[Optional[MyDataclass], "some description"] + """ + origin = typing.get_origin(field_type) + if origin in (typing.Union, UnionType): # handle Optional[type] and 'type | None' + args = typing.get_args(field_type) + if len(args) == 2 and args[-1] == NoneType: + return _get_deserialization_field_type(args[0]) + if origin == typing.Annotated: + args = typing.get_args(field_type) + if len(args) >= 1: + return _get_deserialization_field_type(args[0]) + + return field_type