Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `tilebox-workflows`: Registering duplicate task identifiers with a task runner now raises a `ValueError` instead of
overwriting the existing task.
- `tilebox-workflows`: Fixed a bug where the `deserialize_task` function would fail to deserialize nested dataclasses or
protobuf messages that are wrapped in an `Optional` or `Annotated` type hint.

## [0.41.0] - 2025-08-01

Expand Down
53 changes: 53 additions & 0 deletions tilebox-workflows/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from dataclasses import dataclass
from typing import Annotated

import pytest

Expand All @@ -9,6 +10,7 @@
ExecutionContext,
Task,
TaskMeta,
_get_deserialization_field_type,
deserialize_task,
serialize_task,
)
Expand Down Expand Up @@ -348,3 +350,54 @@ def test_serialize_deserialize_task_nested_protobuf_in_nested_dict() -> None:
{"a": {"b": [SampleArgs(some_string="World", some_int=123), SampleArgs(some_string="!", some_int=456)]}},
)
assert deserialize_task(ExampleTaskWithNestedProtobufInNestedDict, serialize_task(task)) == task


class ExampleTaskWithOptionalNestedJson(Task):
    # Task fixture with one required string field and one optional nested field.
    # NOTE(review): NestedJson is defined outside this view; presumably a dataclass - confirm.
    x: str
    # Defaults to None, so the task can be constructed from `x` alone.
    optional_args: NestedJson | None = None


def test_serialize_deserialize_task_nested_optional_json() -> None:
    """Round-trip ExampleTaskWithOptionalNestedJson with its optional field unset and set."""
    without_args = ExampleTaskWithOptionalNestedJson("Hello")
    with_args = ExampleTaskWithOptionalNestedJson("Hello", NestedJson(nested_x="World"))

    for original in (without_args, with_args):
        payload = serialize_task(original)
        restored = deserialize_task(ExampleTaskWithOptionalNestedJson, payload)
        # deserializing what was serialized must yield an equal task
        assert restored == original


class ExampleTaskWithOptionalNestedProtobuf(Task):
    # Task fixture with one required string field and one optional nested field.
    # NOTE(review): SampleArgs is defined outside this view; presumably a protobuf message - confirm.
    x: str
    # Defaults to None, so the task can be constructed from `x` alone.
    optional_args: SampleArgs | None = None


def test_serialize_deserialize_task_nested_optional_protobuf() -> None:
    """Round-trip ExampleTaskWithOptionalNestedProtobuf with its optional field unset and set."""
    without_args = ExampleTaskWithOptionalNestedProtobuf("Hello")
    with_args = ExampleTaskWithOptionalNestedProtobuf("Hello", SampleArgs(some_string="World", some_int=123))

    for original in (without_args, with_args):
        payload = serialize_task(original)
        restored = deserialize_task(ExampleTaskWithOptionalNestedProtobuf, payload)
        # deserializing what was serialized must yield an equal task
        assert restored == original


class FieldTypesTest(Task):
    # Fixture whose field annotations are read through `__dataclass_fields__` by
    # test_get_deserialization_field_type; none of the fields has a default value.
    field1: str  # plain type
    field2: str | None  # optional builtin type
    field3: NestedJson | None  # optional nested type
    # NOTE(review): identical to field3 as written; possibly meant to be another spelling
    # of an optional (e.g. Optional[NestedJson]) - confirm against the repository.
    field4: NestedJson | None
    field5: Annotated[NestedJson, "some description"]  # annotated, not optional
    field6: Annotated[NestedJson, "some description"] | None  # optional wrapped around annotated
    field7: Annotated[NestedJson | None, "some description"]  # annotated wrapped around optional
    # NOTE(review): identical to field7 as written - confirm whether another spelling was intended.
    field8: Annotated[NestedJson | None, "some description"]
    field9: Annotated[list[NestedJson] | None, "some description"]  # annotated optional generic alias


def test_get_deserialization_field_type() -> None:
    """Check the type _get_deserialization_field_type resolves for each field of FieldTypesTest."""
    declared = FieldTypesTest.__dataclass_fields__

    def resolve(name: str) -> type:
        return _get_deserialization_field_type(declared[name].type)

    for name in ("field1", "field2"):
        assert resolve(name) is str

    for name in ("field3", "field4", "field5", "field6", "field7", "field8"):
        assert resolve(name) is NestedJson

    # list[NestedJson] builds a new alias object on every evaluation, so compare by equality
    assert resolve("field9") == list[NestedJson]
35 changes: 34 additions & 1 deletion tilebox-workflows/tilebox/workflows/task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import inspect
import json
import typing
from abc import ABC, ABCMeta, abstractmethod
from base64 import b64decode, b64encode
from collections.abc import Sequence
from dataclasses import dataclass, field, fields, is_dataclass
from types import NoneType, UnionType
from typing import Any, cast, get_args, get_origin

# from python 3.11 onwards this is available as typing.dataclass_transform:
Expand Down Expand Up @@ -350,7 +352,7 @@ def deserialize_task(task_cls: type, task_input: bytes) -> Task:
return task_cls() # empty task
if len(task_fields) == 1:
# if there is only one field, we deserialize it directly
field_type = task_fields[0].type
field_type = _get_deserialization_field_type(task_fields[0].type) # type: ignore[arg-type]
if hasattr(field_type, "FromString"): # protobuf message
value = field_type.FromString(task_input) # type: ignore[arg-type]
else:
Expand All @@ -372,6 +374,10 @@ def _deserialize_dataclass(cls: type, params: dict[str, Any]) -> Task:


def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911
if value is None:
return None

field_type = _get_deserialization_field_type(field_type)
if hasattr(field_type, "FromString"):
return field_type.FromString(b64decode(value))
if is_dataclass(field_type) and isinstance(value, dict):
Expand All @@ -398,3 +404,30 @@ def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911
return {k: _deserialize_value(type_args[1], v) for k, v in value.items()}

return value


def _get_deserialization_field_type(field_type: type) -> type:
"""
Get the actual underlying type we want to deserialize a field type annotated as.

This correctly handles optional and annotated type hints.

For example, all of the following fields should be deserialized as MyDataclass class

field1: MyDataclass
field2: MyDataclass | None
field3: Optional[MyDataclass]
field4: Annotated[MyDataclass, "some description"]
field5: Annotated[Optional[MyDataclass], "some description"
"""
origin = typing.get_origin(field_type)
if origin in (typing.Union, UnionType): # handle Optional[type] and 'type | None'
args = typing.get_args(field_type)
if len(args) == 2 and args[-1] == NoneType:
return _get_deserialization_field_type(args[0])
if origin == typing.Annotated:
args = typing.get_args(field_type)
if len(args) >= 1:
return _get_deserialization_field_type(args[0])

return field_type
Loading