diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index dbf27b1..924d3bf 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -9,6 +9,7 @@ on: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + VIRTUALENV_ACTIVATORS: "bash,python" jobs: release: diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 6b3a36d..12048c4 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -6,6 +6,9 @@ on: pull_request: branches: [main] +env: + VIRTUALENV_ACTIVATORS: "bash,python" + jobs: check-branch: if: ${{ github.event_name == 'pull_request' }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 39ce668..9464285 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -7,6 +7,9 @@ on: pull_request: branches: [main] +env: + VIRTUALENV_ACTIVATORS: "bash,python" + jobs: release: name: Release diff --git a/docs/roadmap.md b/docs/roadmap.md index 839a8ca..5f06d93 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -72,3 +72,18 @@ columns. Static-schema values use `dataframe[name: T, ...]`, column views use - [ ] Expand runtime-layout/schema annotations beyond function and extern parameters, applying the same behavior to both `dataframe[...]` and `tensor[T, ...]`. + +## Type System Follow-ups + +- [ ] Add parser-level support for optional list, tensor, series, and dataframe + size/shape annotations once the surface-syntax restrictions are ready to + change. +- [ ] Add runtime check sidecars for assigning unknown-size values to sized + targets, including list length, tensor shape, series length, and dataframe + row count; sized annotations must not become trusted metadata until the + runtime check has passed. +- [ ] Add support for partial tensor shape constraints using ellipsis, such as + `tensor[f64, 2, ...]`, `tensor[f64, ..., 3]`, and + `tensor[f64, 2, ..., 3]`. +- [ ] Add symbolic shape variables for generic algorithms, such as + `fn dot[N](a: tensor[f64, N], b: tensor[f64, N])`. diff --git a/packages/arx/src/arx/tensor.py b/packages/arx/src/arx/tensor.py index f956764..e8645ad 100644 --- a/packages/arx/src/arx/tensor.py +++ b/packages/arx/src/arx/tensor.py @@ -58,6 +58,9 @@ def _shape_of(data_type: astx.DataType) -> tuple[int, ...] | None: returns: type: tuple[int, Ellipsis] | None """ + if isinstance(data_type, astx.TensorType) and data_type.shape is not None: + return data_type.shape + shape = getattr(data_type, TENSOR_SHAPE_ATTR, None) if isinstance(shape, tuple) and all(isinstance(dim, int) for dim in shape): return cast(tuple[int, ...], shape) @@ -79,6 +82,7 @@ def _mark_tensor_type( type: astx.TensorType """ setattr(data_type, TENSOR_SURFACE_ATTR, True) + data_type.shape = shape if shape is not None: setattr(data_type, TENSOR_SHAPE_ATTR, shape) return data_type @@ -132,7 +136,10 @@ def tensor_type( raise ValueError("tensor shapes must include at least one dimension") if any(dim < 0 for dim in shape): raise ValueError("tensor dimensions must be non-negative") - return _mark_tensor_type(astx.TensorType(element_type), shape) + return _mark_tensor_type( + astx.TensorType(element_type, shape=shape), + shape, + ) def runtime_tensor_type(element_type: astx.DataType) -> astx.TensorType: diff --git a/packages/astx/src/astx/dataframe.py b/packages/astx/src/astx/dataframe.py index 494317b..1210800 100644 --- a/packages/astx/src/astx/dataframe.py +++ b/packages/astx/src/astx/dataframe.py @@ -17,6 +17,28 @@ from astx.types import AnyType +@typechecked +def _format_type_name(type_: astx.DataType) -> str: + """ + title: Return one stable DataFrame type name. + parameters: + type_: + type: astx.DataType + returns: + type: str + """ + if isinstance(type_, astx.PointerType): + if type_.pointee_type is None: + return "PointerType" + return f"PointerType[{_format_type_name(type_.pointee_type)}]" + + value = str(type_) + default_name = f"{type_.__class__.__name__}: {type_.name}" + if value == default_name: + return type_.__class__.__name__ + return value + + @typechecked @dataclass(frozen=True) class DataFrameColumn: @@ -93,16 +115,20 @@ class SeriesType(AnyType): type: astx.DataType | None nullable: type: bool + size: + type: int | None """ element_type: astx.DataType | None nullable: bool + size: int | None def __init__( self, element_type: astx.DataType | None = None, *, nullable: bool = False, + size: int | None = None, ) -> None: """ title: Initialize one Series type. @@ -111,10 +137,15 @@ def __init__( type: astx.DataType | None nullable: type: bool + size: + type: int | None """ super().__init__() + if size is not None and size < 0: + raise ValueError("series size must be non-negative") self.element_type = element_type self.nullable = nullable + self.size = size def __str__(self) -> str: """ @@ -124,7 +155,10 @@ def __str__(self) -> str: """ if self.element_type is None: return "SeriesType" - return f"SeriesType[{self.element_type}]" + parts = [_format_type_name(self.element_type)] + if self.size is not None: + parts.append(str(self.size)) + return f"SeriesType[{', '.join(parts)}]" @typechecked @@ -137,22 +171,32 @@ class DataFrameType(AnyType): attributes: columns: type: tuple[DataFrameColumn, Ellipsis] | None + row_count: + type: int | None """ columns: tuple[DataFrameColumn, ...] | None + row_count: int | None def __init__( self, columns: Sequence[DataFrameColumn] | None = None, + *, + row_count: int | None = None, ) -> None: """ title: Initialize one DataFrame type. parameters: columns: type: Sequence[DataFrameColumn] | None + row_count: + type: int | None """ super().__init__() + if row_count is not None and row_count < 0: + raise ValueError("dataframe row count must be non-negative") self.columns = None if columns is None else tuple(columns) + self.row_count = row_count def __str__(self) -> str: """ @@ -161,11 +205,38 @@ def __str__(self) -> str: type: str """ if self.columns is None: + if self.row_count is not None: + return f"DataFrameType[{self.row_count}]" return "DataFrameType" - columns = ", ".join( - f"{column.name}: {column.type_}" for column in self.columns - ) - return f"DataFrameType[{columns}]" + parts = [ + f"{column.name}: {_format_type_name(column.type_)}" + for column in self.columns + ] + if self.row_count is not None: + parts.append(str(self.row_count)) + return f"DataFrameType[{', '.join(parts)}]" + + +@typechecked +def _infer_literal_row_count( + columns: tuple[DataFrameLiteralColumn, ...], +) -> int | None: + """ + title: Infer row count for a DataFrame literal when statically consistent. + parameters: + columns: + type: tuple[DataFrameLiteralColumn, Ellipsis] + returns: + type: int | None + """ + row_count: int | None = None + for column in columns: + if row_count is None: + row_count = len(column.values) + continue + if len(column.values) != row_count: + return None + return 0 if row_count is None else row_count @typechecked @@ -199,7 +270,8 @@ def __init__( """ super().__init__() self.columns = tuple(columns) - self.type_ = type_ or DataFrameType() + row_count = _infer_literal_row_count(self.columns) + self.type_ = type_ or DataFrameType(row_count=row_count) def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: """ @@ -210,19 +282,17 @@ def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: returns: type: astx.base.ReprStruct """ - value = { + value: dict[str, object] = { "columns": [ column.get_struct(simplified) for column in self.columns ], - "type": ( - None - if self.type_.columns is None - else [ - column.get_struct(simplified) - for column in self.type_.columns - ] - ), } + if self.type_.columns is not None: + value["type"] = [ + column.get_struct(simplified) for column in self.type_.columns + ] + if self.type_.row_count is not None: + value["row_count"] = self.type_.row_count return self._prepare_struct( "DataFrameLiteral", cast(astx.base.ReprStruct, value), diff --git a/packages/astx/src/astx/literals/collections.py b/packages/astx/src/astx/literals/collections.py index f953340..ff7eaa2 100644 --- a/packages/astx/src/astx/literals/collections.py +++ b/packages/astx/src/astx/literals/collections.py @@ -46,7 +46,7 @@ def __init__( super().__init__(loc) self.elements = list(elements) # Ensure correct type unique_types = {type(elem.type_) for elem in elements} - self.type_ = ListType([t() for t in unique_types]) + self.type_ = ListType([t() for t in unique_types], size=len(elements)) self.loc = loc diff --git a/packages/astx/src/astx/tensor.py b/packages/astx/src/astx/tensor.py index 3c9a701..b80f598 100644 --- a/packages/astx/src/astx/tensor.py +++ b/packages/astx/src/astx/tensor.py @@ -16,6 +16,28 @@ from astx.types import AnyType +@typechecked +def _format_type_name(type_: astx.DataType) -> str: + """ + title: Return one stable tensor element type name. + parameters: + type_: + type: astx.DataType + returns: + type: str + """ + if isinstance(type_, astx.PointerType): + if type_.pointee_type is None: + return "PointerType" + return f"PointerType[{_format_type_name(type_.pointee_type)}]" + + value = str(type_) + default_name = f"{type_.__class__.__name__}: {type_.name}" + if value == default_name: + return type_.__class__.__name__ + return value + + @typechecked class TensorType(AnyType): """ @@ -26,19 +48,32 @@ class TensorType(AnyType): attributes: element_type: type: astx.DataType | None + shape: + type: tuple[int, Ellipsis] | None """ element_type: astx.DataType | None + shape: tuple[int, ...] | None - def __init__(self, element_type: astx.DataType | None = None) -> None: + def __init__( + self, + element_type: astx.DataType | None = None, + *, + shape: Sequence[int] | None = None, + ) -> None: """ title: Initialize one Tensor type. parameters: element_type: type: astx.DataType | None + shape: + type: Sequence[int] | None """ super().__init__() self.element_type = element_type + self.shape = None if shape is None else tuple(shape) + if self.shape is not None and any(dim < 0 for dim in self.shape): + raise ValueError("tensor shape dimensions must be non-negative") def __str__(self) -> str: """ @@ -48,7 +83,10 @@ def __str__(self) -> str: """ if self.element_type is None: return "TensorType" - return f"TensorType[{self.element_type}]" + parts = [_format_type_name(self.element_type)] + if self.shape is not None: + parts.extend(str(dimension) for dimension in self.shape) + return f"TensorType[{', '.join(parts)}]" @typechecked @@ -109,7 +147,7 @@ def __init__( self.shape = tuple(shape) self.strides = None if strides is None else tuple(strides) self.offset_bytes = offset_bytes - self.type_ = TensorType(element_type) + self.type_ = TensorType(element_type, shape=self.shape) def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: """ @@ -185,7 +223,7 @@ def __init__( self.shape = tuple(shape) self.strides = None if strides is None else tuple(strides) self.offset_bytes = offset_bytes - self.type_ = TensorType() + self.type_ = TensorType(shape=self.shape) def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: """ diff --git a/packages/astx/src/astx/types/collections.py b/packages/astx/src/astx/types/collections.py index 8762b31..a609ae5 100644 --- a/packages/astx/src/astx/types/collections.py +++ b/packages/astx/src/astx/types/collections.py @@ -11,6 +11,27 @@ from astx.types.base import AnyType +@typechecked +def _format_type_name(type_: ExprType) -> str: + """ + title: Return one stable collection element type name. + parameters: + type_: + type: ExprType + returns: + type: str + """ + pointee_type = getattr(type_, "pointee_type", None) + if isinstance(pointee_type, ExprType): + return f"{type_.__class__.__name__}[{_format_type_name(pointee_type)}]" + + value = str(type_) + default_name = f"{type_.__class__.__name__}: {getattr(type_, 'name', '')}" + if value == default_name: + return type_.__class__.__name__ + return value + + @public @typechecked class CollectionType(AnyType): @@ -27,18 +48,31 @@ class ListType(CollectionType): attributes: element_types: type: list[ExprType] + size: + type: int | None """ element_types: list[ExprType] + size: int | None - def __init__(self, element_types: list[ExprType]) -> None: + def __init__( + self, + element_types: list[ExprType], + *, + size: int | None = None, + ) -> None: """ title: Initialize ListType with an element type. parameters: element_types: type: list[ExprType] + size: + type: int | None """ + if size is not None and size < 0: + raise ValueError("list size must be non-negative") self.element_types = element_types + self.size = size def __str__(self) -> str: """ @@ -46,7 +80,10 @@ def __str__(self) -> str: returns: type: str """ - types_str = ", ".join(str(t) for t in self.element_types) + parts = [_format_type_name(type_) for type_ in self.element_types] + if self.size is not None: + parts.append(str(self.size)) + types_str = ", ".join(parts) return f"ListType[{types_str}]" diff --git a/packages/astx/tests/test_sized_type_metadata.py b/packages/astx/tests/test_sized_type_metadata.py new file mode 100644 index 0000000..35cb69a --- /dev/null +++ b/packages/astx/tests/test_sized_type_metadata.py @@ -0,0 +1,154 @@ +""" +title: Tests for optional sized type metadata. +""" + +from __future__ import annotations + +import pytest + +import astx + +LIST_SIZE = 4 +SERIES_SIZE = 100 +DATAFRAME_ROW_COUNT = 100 +INFERRED_ROW_COUNT = 2 + + +def test_list_type_optional_size_metadata() -> None: + """ + title: ListType stores optional static sizes. + """ + unconstrained = astx.ListType([astx.Int32()]) + sized = astx.ListType([astx.Int32()], size=LIST_SIZE) + + assert unconstrained.size is None + assert sized.size == LIST_SIZE + assert str(unconstrained) == "ListType[Int32]" + assert str(sized) == "ListType[Int32, 4]" + with pytest.raises(ValueError, match="non-negative"): + astx.ListType([astx.Int32()], size=-1) + + +def test_sized_type_strings_preserve_nested_type_detail() -> None: + """ + title: Sized type strings preserve nested parameterized type details. + """ + pointer = astx.PointerType(astx.Int32()) + + assert str(astx.ListType([pointer], size=LIST_SIZE)) == ( + "ListType[PointerType[Int32], 4]" + ) + + +def test_tensor_type_optional_shape_metadata() -> None: + """ + title: TensorType stores optional static shapes. + """ + unconstrained = astx.TensorType(astx.Float64()) + shaped = astx.TensorType(astx.Float64(), shape=(2, 3)) + + assert unconstrained.shape is None + assert shaped.shape == (2, 3) + assert str(unconstrained) == "TensorType[Float64]" + assert str(shaped) == "TensorType[Float64, 2, 3]" + with pytest.raises(ValueError, match="non-negative"): + astx.TensorType(astx.Float64(), shape=(2, -1)) + + +def test_tensor_type_string_preserves_nested_type_detail() -> None: + """ + title: Tensor type strings preserve nested parameterized type details. + """ + tensor = astx.TensorType(astx.PointerType(astx.Int32()), shape=(2, 3)) + + assert str(tensor) == "TensorType[PointerType[Int32], 2, 3]" + + +def test_tensor_literal_and_view_preserve_shape_metadata() -> None: + """ + title: Tensor-producing nodes preserve static shapes on their types. + """ + literal = astx.TensorLiteral( + [ + astx.LiteralFloat64(1.0), + astx.LiteralFloat64(2.0), + astx.LiteralFloat64(3.0), + astx.LiteralFloat64(4.0), + astx.LiteralFloat64(5.0), + astx.LiteralFloat64(6.0), + ], + element_type=astx.Float64(), + shape=(2, 3), + ) + view = astx.TensorView(literal, shape=(3, 2)) + + assert literal.type_.shape == (2, 3) + assert view.type_.shape == (3, 2) + + +def test_series_type_optional_size_metadata() -> None: + """ + title: SeriesType stores optional static sizes. + """ + unconstrained = astx.SeriesType(astx.Int32()) + sized = astx.SeriesType(astx.Int32(), size=SERIES_SIZE) + + assert unconstrained.size is None + assert sized.size == SERIES_SIZE + assert str(unconstrained) == "SeriesType[Int32]" + assert str(sized) == "SeriesType[Int32, 100]" + with pytest.raises(ValueError, match="non-negative"): + astx.SeriesType(astx.Int32(), size=-1) + + +def test_dataframe_type_optional_row_count_metadata() -> None: + """ + title: DataFrameType stores optional static row counts. + """ + columns = (astx.DataFrameColumn("age", astx.Int32()),) + unconstrained = astx.DataFrameType(columns) + sized = astx.DataFrameType(columns, row_count=DATAFRAME_ROW_COUNT) + + assert unconstrained.row_count is None + assert sized.row_count == DATAFRAME_ROW_COUNT + assert str(unconstrained) == "DataFrameType[age: Int32]" + assert str(sized) == "DataFrameType[age: Int32, 100]" + with pytest.raises(ValueError, match="non-negative"): + astx.DataFrameType(columns, row_count=-1) + + +def test_dataframe_type_string_preserves_nested_type_detail() -> None: + """ + title: DataFrame type strings preserve nested parameterized type details. + """ + rows = astx.DataFrameType( + (astx.DataFrameColumn("ptr", astx.PointerType(astx.Int32())),), + row_count=DATAFRAME_ROW_COUNT, + ) + + assert str(rows) == "DataFrameType[ptr: PointerType[Int32], 100]" + + +def test_dataframe_literal_infers_unconstrained_row_count() -> None: + """ + title: DataFrameLiteral infers row counts when no explicit type is given. + """ + literal = astx.DataFrameLiteral( + ( + astx.DataFrameLiteralColumn( + "age", + (astx.LiteralInt32(1), astx.LiteralInt32(2)), + ), + ) + ) + explicit_type = astx.DataFrameType( + (astx.DataFrameColumn("age", astx.Int32()),), + row_count=DATAFRAME_ROW_COUNT, + ) + explicit_literal = astx.DataFrameLiteral( + literal.columns, + type_=explicit_type, + ) + + assert literal.type_.row_count == INFERRED_ROW_COUNT + assert explicit_literal.type_.row_count == DATAFRAME_ROW_COUNT diff --git a/packages/irx/src/irx/analysis/handlers/_expressions/dataframes.py b/packages/irx/src/irx/analysis/handlers/_expressions/dataframes.py index 020e0ad..a24d1b4 100644 --- a/packages/irx/src/irx/analysis/handlers/_expressions/dataframes.py +++ b/packages/irx/src/irx/analysis/handlers/_expressions/dataframes.py @@ -111,6 +111,7 @@ def _visit_dataframe_column_access( node.type_ = astx.SeriesType( column.type_, nullable=column.nullable, + size=base_type.row_count, ) self._semantic(node).extras[DATAFRAME_COLUMN_INDEX_EXTRA] = ( column.index diff --git a/packages/irx/src/irx/analysis/handlers/_expressions/tensors.py b/packages/irx/src/irx/analysis/handlers/_expressions/tensors.py index 7ae4292..4850065 100644 --- a/packages/irx/src/irx/analysis/handlers/_expressions/tensors.py +++ b/packages/irx/src/irx/analysis/handlers/_expressions/tensors.py @@ -144,7 +144,7 @@ def visit(self, node: astx.TensorLiteral) -> None: self._semantic(node).extras[TENSOR_ELEMENT_TYPE_EXTRA] = ( node.element_type ) - node.type_ = astx.TensorType(node.element_type) + node.type_ = astx.TensorType(node.element_type, shape=shape) self._set_type(node, node.type_) @SemanticAnalyzerCore.visit.dispatch @@ -300,7 +300,7 @@ def visit(self, node: astx.TensorView) -> None: self._semantic(node).extras[TENSOR_ELEMENT_TYPE_EXTRA] = ( element_type ) - node.type_ = astx.TensorType(element_type) + node.type_ = astx.TensorType(element_type, shape=shape) self._semantic(node).extras[TENSOR_LAYOUT_EXTRA] = layout self._semantic(node).extras[TENSOR_FLAGS_EXTRA] = flags self._set_type(node, node.type_) diff --git a/packages/irx/src/irx/analysis/handlers/_templates/support.py b/packages/irx/src/irx/analysis/handlers/_templates/support.py index e4a4289..a7375e9 100644 --- a/packages/irx/src/irx/analysis/handlers/_templates/support.py +++ b/packages/irx/src/irx/analysis/handlers/_templates/support.py @@ -199,7 +199,7 @@ def _substitute_type( if type_.element_type is not None else None ) - return astx.TensorType(element_type) + return astx.TensorType(element_type, shape=type_.shape) return clone_type(type_) def _template_bindings_map( diff --git a/packages/irx/src/irx/analysis/types.py b/packages/irx/src/irx/analysis/types.py index d5ad051..72f4b93 100644 --- a/packages/irx/src/irx/analysis/types.py +++ b/packages/irx/src/irx/analysis/types.py @@ -60,6 +60,152 @@ } +@typechecked +def _same_optional_size(lhs: int | None, rhs: int | None) -> bool: + """ + title: Return whether optional collection sizes are compatible. + parameters: + lhs: + type: int | None + rhs: + type: int | None + returns: + type: bool + """ + return lhs is None or rhs is None or lhs == rhs + + +@typechecked +def _same_optional_shape( + lhs: tuple[int, ...] | None, + rhs: tuple[int, ...] | None, +) -> bool: + """ + title: Return whether optional tensor shapes are compatible. + parameters: + lhs: + type: tuple[int, Ellipsis] | None + rhs: + type: tuple[int, Ellipsis] | None + returns: + type: bool + """ + return lhs is None or rhs is None or lhs == rhs + + +@typechecked +def _target_accepts_known_size( + target_size: int | None, + value_size: int | None, +) -> bool: + """ + title: Return whether a target size accepts a value size. + parameters: + target_size: + type: int | None + value_size: + type: int | None + returns: + type: bool + """ + return target_size is None or ( + value_size is not None and target_size == value_size + ) + + +@typechecked +def _target_accepts_known_shape( + target_shape: tuple[int, ...] | None, + value_shape: tuple[int, ...] | None, +) -> bool: + """ + title: Return whether a target shape accepts a value shape. + parameters: + target_shape: + type: tuple[int, Ellipsis] | None + value_shape: + type: tuple[int, Ellipsis] | None + returns: + type: bool + """ + return target_shape is None or ( + value_shape is not None and target_shape == value_shape + ) + + +@typechecked +def _metadata_assignment_compatible( + target: astx.DataType, + value: astx.DataType, +) -> bool: + """ + title: Return whether assignment metadata is statically compatible. + parameters: + target: + type: astx.DataType + value: + type: astx.DataType + returns: + type: bool + """ + if isinstance(target, astx.ListType) and isinstance(value, astx.ListType): + return _target_accepts_known_size(target.size, value.size) + if isinstance(target, astx.TensorType) and isinstance( + value, + astx.TensorType, + ): + return _target_accepts_known_shape(target.shape, value.shape) + if isinstance(target, astx.SeriesType) and isinstance( + value, + astx.SeriesType, + ): + return _target_accepts_known_size(target.size, value.size) + if isinstance(target, astx.DataFrameType) and isinstance( + value, + astx.DataFrameType, + ): + return _target_accepts_known_size(target.row_count, value.row_count) + return True + + +@public +@typechecked +def requires_size_check( + target_size: int | None, + value_size: int | None, +) -> bool: + """ + title: Return whether assignment requires a runtime size check. + parameters: + target_size: + type: int | None + value_size: + type: int | None + returns: + type: bool + """ + return target_size is not None and value_size is None + + +@public +@typechecked +def requires_shape_check( + target_shape: tuple[int, ...] | None, + value_shape: tuple[int, ...] | None, +) -> bool: + """ + title: Return whether assignment requires a runtime shape check. + parameters: + target_shape: + type: tuple[int, Ellipsis] | None + value_shape: + type: tuple[int, Ellipsis] | None + returns: + type: bool + """ + return target_shape is not None and value_shape is None + + @public @typechecked def clone_type(type_: astx.DataType) -> astx.DataType: @@ -116,7 +262,8 @@ def clone_type(type_: astx.DataType) -> astx.DataType: [ clone_type(cast(astx.DataType, element_type)) for element_type in type_.element_types - ] + ], + size=type_.size, ) if isinstance(type_, astx.TupleType): return astx.TupleType( @@ -151,17 +298,21 @@ def clone_type(type_: astx.DataType) -> astx.DataType: if type_.element_type is not None else None ) - return astx.TensorType(element_type) + return astx.TensorType(element_type, shape=type_.shape) if isinstance(type_, astx.SeriesType): element_type = ( clone_type(type_.element_type) if type_.element_type is not None else None ) - return astx.SeriesType(element_type, nullable=type_.nullable) + return astx.SeriesType( + element_type, + nullable=type_.nullable, + size=type_.size, + ) if isinstance(type_, astx.DataFrameType): if type_.columns is None: - return astx.DataFrameType() + return astx.DataFrameType(row_count=type_.row_count) return astx.DataFrameType( tuple( astx.DataFrameColumn( @@ -170,7 +321,8 @@ def clone_type(type_: astx.DataType) -> astx.DataType: nullable=column.nullable, ) for column in type_.columns - ) + ), + row_count=type_.row_count, ) return type_.__class__() @@ -211,15 +363,16 @@ def display_type_name(type_: astx.DataType | None) -> str: return f"PointerType[{display_type_name(type_.pointee_type)}]" if isinstance(type_, astx.ListType): if not type_.element_types: + if type_.size is not None: + return f"ListType[{type_.size}]" return "ListType" - return ( - "ListType[" - + ", ".join( - display_type_name(cast(astx.DataType, member)) - for member in type_.element_types - ) - + "]" - ) + parts = [ + display_type_name(cast(astx.DataType, member)) + for member in type_.element_types + ] + if type_.size is not None: + parts.append(str(type_.size)) + return "ListType[" + ", ".join(parts) + "]" if isinstance(type_, astx.TupleType): if not type_.element_types: return "TupleType" @@ -249,19 +402,29 @@ def display_type_name(type_: astx.DataType | None) -> str: if isinstance(type_, astx.TensorType): if type_.element_type is None: return "TensorType" - return f"TensorType[{display_type_name(type_.element_type)}]" + parts = [display_type_name(type_.element_type)] + if type_.shape is not None: + parts.extend(str(dimension) for dimension in type_.shape) + return "TensorType[" + ", ".join(parts) + "]" if isinstance(type_, astx.SeriesType): if type_.element_type is None: return "SeriesType" - return f"SeriesType[{display_type_name(type_.element_type)}]" + parts = [display_type_name(type_.element_type)] + if type_.size is not None: + parts.append(str(type_.size)) + return "SeriesType[" + ", ".join(parts) + "]" if isinstance(type_, astx.DataFrameType): if type_.columns is None: + if type_.row_count is not None: + return f"DataFrameType[{type_.row_count}]" return "DataFrameType" - columns = ", ".join( + parts = [ f"{column.name}: {display_type_name(column.type_)}" for column in type_.columns - ) - return f"DataFrameType[{columns}]" + ] + if type_.row_count is not None: + parts.append(str(type_.row_count)) + return "DataFrameType[" + ", ".join(parts) + "]" return str(type_.__class__.__name__) @@ -320,6 +483,8 @@ def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: return lhs.pointee_type is None and rhs.pointee_type is None return same_type(lhs.pointee_type, rhs.pointee_type) if isinstance(lhs, astx.ListType) and isinstance(rhs, astx.ListType): + if not _same_optional_size(lhs.size, rhs.size): + return False if len(lhs.element_types) != len(rhs.element_types): return False return all( @@ -374,6 +539,8 @@ def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: rhs, astx.TensorType, ): + if not _same_optional_shape(lhs.shape, rhs.shape): + return False if lhs.element_type is None or rhs.element_type is None: return True return same_type(lhs.element_type, rhs.element_type) @@ -381,6 +548,8 @@ def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: rhs, astx.SeriesType, ): + if not _same_optional_size(lhs.size, rhs.size): + return False if lhs.element_type is None or rhs.element_type is None: return True return lhs.nullable == rhs.nullable and same_type( @@ -390,6 +559,8 @@ def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: rhs, astx.DataFrameType, ): + if not _same_optional_size(lhs.row_count, rhs.row_count): + return False if lhs.columns is None or rhs.columns is None: return True if len(lhs.columns) != len(rhs.columns): @@ -784,6 +955,8 @@ def is_assignable( """ if target is None or value is None: return True + if not _metadata_assignment_compatible(target, value): + return False if same_type(target, value): return True if isinstance(target, astx.UnionType): @@ -800,6 +973,8 @@ def is_assignable( ): return is_assignable(target.yield_type, value.yield_type) if isinstance(target, astx.ListType) and isinstance(value, astx.ListType): + if not _target_accepts_known_size(target.size, value.size): + return False if not target.element_types or not value.element_types: return True target_members = [ @@ -844,6 +1019,47 @@ def is_assignable( cast(astx.DataType, target.value_type), cast(astx.DataType, value.value_type), ) + if isinstance(target, astx.TensorType) and isinstance( + value, + astx.TensorType, + ): + if not _target_accepts_known_shape(target.shape, value.shape): + return False + if target.element_type is None or value.element_type is None: + return True + return same_type(target.element_type, value.element_type) + if isinstance(target, astx.SeriesType) and isinstance( + value, + astx.SeriesType, + ): + if not _target_accepts_known_size(target.size, value.size): + return False + if target.element_type is None or value.element_type is None: + return True + return target.nullable == value.nullable and same_type( + target.element_type, + value.element_type, + ) + if isinstance(target, astx.DataFrameType) and isinstance( + value, + astx.DataFrameType, + ): + if not _target_accepts_known_size(target.row_count, value.row_count): + return False + if target.columns is None or value.columns is None: + return True + if len(target.columns) != len(value.columns): + return False + return all( + target_column.name == value_column.name + and target_column.nullable == value_column.nullable + and same_type(target_column.type_, value_column.type_) + for target_column, value_column in zip( + target.columns, + value.columns, + strict=True, + ) + ) if isinstance(target, astx.ClassType) and isinstance( value, astx.ClassType ): diff --git a/packages/irx/tests/analysis/test_sized_type_metadata.py b/packages/irx/tests/analysis/test_sized_type_metadata.py new file mode 100644 index 0000000..8067bc6 --- /dev/null +++ b/packages/irx/tests/analysis/test_sized_type_metadata.py @@ -0,0 +1,170 @@ +""" +title: Tests for sized type metadata helpers. +""" + +from __future__ import annotations + +import astx + +from irx.analysis.types import ( + clone_type, + display_type_name, + is_assignable, + requires_shape_check, + requires_size_check, + same_type, +) + +LIST_SIZE = 4 +SERIES_SIZE = 100 +DATAFRAME_ROW_COUNT = 100 + + +def _dataframe_columns() -> tuple[astx.DataFrameColumn, ...]: + """ + title: Build a small dataframe schema for metadata tests. + returns: + type: tuple[astx.DataFrameColumn, Ellipsis] + """ + return (astx.DataFrameColumn("age", astx.Int32()),) + + +def test_clone_type_preserves_sized_metadata() -> None: + """ + title: clone_type preserves collection size and shape metadata. + """ + list_type = clone_type(astx.ListType([astx.Int32()], size=LIST_SIZE)) + tensor_type = clone_type(astx.TensorType(astx.Float64(), shape=(2, 3))) + series_type = clone_type(astx.SeriesType(astx.Int32(), size=SERIES_SIZE)) + dataframe_type = clone_type( + astx.DataFrameType( + _dataframe_columns(), + row_count=DATAFRAME_ROW_COUNT, + ) + ) + + assert isinstance(list_type, astx.ListType) + assert list_type.size == LIST_SIZE + assert isinstance(tensor_type, astx.TensorType) + assert tensor_type.shape == (2, 3) + assert isinstance(series_type, astx.SeriesType) + assert series_type.size == SERIES_SIZE + assert isinstance(dataframe_type, astx.DataFrameType) + assert dataframe_type.row_count == DATAFRAME_ROW_COUNT + + +def test_display_type_name_canonicalizes_sized_metadata() -> None: + """ + title: display_type_name omits unconstrained and renders constrained sizes. + """ + assert display_type_name(astx.ListType([astx.Int32()])) == ( + "ListType[Int32]" + ) + assert display_type_name(astx.TensorType(astx.Float64())) == ( + "TensorType[Float64]" + ) + assert display_type_name(astx.SeriesType(astx.Int32())) == ( + "SeriesType[Int32]" + ) + assert display_type_name(astx.DataFrameType(_dataframe_columns())) == ( + "DataFrameType[age: Int32]" + ) + + assert display_type_name(astx.ListType([astx.Int32()], size=4)) == ( + "ListType[Int32, 4]" + ) + assert ( + display_type_name(astx.TensorType(astx.Float64(), shape=(2, 3))) + == "TensorType[Float64, 2, 3]" + ) + assert display_type_name(astx.SeriesType(astx.Int32(), size=100)) == ( + "SeriesType[Int32, 100]" + ) + assert ( + display_type_name( + astx.DataFrameType(_dataframe_columns(), row_count=100) + ) + == "DataFrameType[age: Int32, 100]" + ) + + +def test_same_type_allows_unconstrained_tensor_shape_wildcard() -> None: + """ + title: Tensor same_type keeps unconstrained shape wildcard behavior. + """ + assert same_type( + astx.TensorType(astx.Float64()), + astx.TensorType(astx.Float64(), shape=(2, 3)), + ) + assert not same_type( + astx.TensorType(astx.Float64(), shape=(2, 3)), + astx.TensorType(astx.Float64(), shape=(3, 2)), + ) + + +def test_tensor_assignability_checks_known_shapes() -> None: + """ + title: Tensor assignability rejects unchecked shape narrowing. + """ + assert is_assignable( + astx.TensorType(astx.Float64(), shape=(2, 3)), + astx.TensorType(astx.Float64(), shape=(2, 3)), + ) + assert not is_assignable( + astx.TensorType(astx.Float64(), shape=(2, 3)), + astx.TensorType(astx.Float64(), shape=(3, 2)), + ) + assert is_assignable( + astx.TensorType(astx.Float64()), + astx.TensorType(astx.Float64(), shape=(2, 3)), + ) + assert not is_assignable( + astx.TensorType(astx.Float64(), shape=(2, 3)), + astx.TensorType(astx.Float64()), + ) + + +def test_sized_assignability_rejects_unchecked_size_narrowing() -> None: + """ + title: Sized assignment rejects unknown-to-known narrowing for all types. + """ + assert not is_assignable( + astx.ListType([astx.Int32()], size=LIST_SIZE), + astx.ListType([astx.Int32()]), + ) + assert is_assignable( + astx.ListType([astx.Int32()]), + astx.ListType([astx.Int32()], size=LIST_SIZE), + ) + assert not is_assignable( + astx.SeriesType(astx.Int32(), size=SERIES_SIZE), + astx.SeriesType(astx.Int32()), + ) + assert is_assignable( + astx.SeriesType(astx.Int32()), + astx.SeriesType(astx.Int32(), size=SERIES_SIZE), + ) + assert not is_assignable( + astx.DataFrameType( + _dataframe_columns(), + row_count=DATAFRAME_ROW_COUNT, + ), + astx.DataFrameType(_dataframe_columns()), + ) + assert is_assignable( + astx.DataFrameType(_dataframe_columns()), + astx.DataFrameType( + _dataframe_columns(), + row_count=DATAFRAME_ROW_COUNT, + ), + ) + + +def test_runtime_check_helpers_detect_unknown_metadata_narrowing() -> None: + """ + title: Runtime-check helpers identify unknown-to-known size narrowing. + """ + assert requires_size_check(4, None) + assert not requires_size_check(None, 4) + assert requires_shape_check((2, 3), None) + assert not requires_shape_check(None, (2, 3))