Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ on:

env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
VIRTUALENV_ACTIVATORS: "bash,python"

jobs:
release:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ on:
pull_request:
branches: [main]

env:
VIRTUALENV_ACTIVATORS: "bash,python"

jobs:
check-branch:
if: ${{ github.event_name == 'pull_request' }}
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ on:
pull_request:
branches: [main]

env:
VIRTUALENV_ACTIVATORS: "bash,python"

jobs:
release:
name: Release
Expand Down
15 changes: 15 additions & 0 deletions docs/roadmap.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,18 @@ columns. Static-schema values use `dataframe[name: T, ...]`, column views use
- [ ] Expand runtime-layout/schema annotations beyond function and extern
parameters, applying the same behavior to both `dataframe[...]` and
`tensor[T, ...]`.

## Type System Follow-ups

- [ ] Add parser-level support for optional list, tensor, series, and dataframe
size/shape annotations once the surface-syntax restrictions are ready to
change.
- [ ] Add runtime check sidecars for assigning unknown-size values to sized
targets, including list length, tensor shape, series length, and dataframe
row count; sized annotations must not become trusted metadata until the
runtime check has passed.
- [ ] Add support for partial tensor shape constraints using ellipsis, such as
`tensor[f64, 2, ...]`, `tensor[f64, ..., 3]`, and
`tensor[f64, 2, ..., 3]`.
- [ ] Add symbolic shape variables for generic algorithms, such as
`fn dot[N](a: tensor[f64, N], b: tensor[f64, N])`.
9 changes: 8 additions & 1 deletion packages/arx/src/arx/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def _shape_of(data_type: astx.DataType) -> tuple[int, ...] | None:
returns:
type: tuple[int, Ellipsis] | None
"""
if isinstance(data_type, astx.TensorType) and data_type.shape is not None:
return data_type.shape

shape = getattr(data_type, TENSOR_SHAPE_ATTR, None)
if isinstance(shape, tuple) and all(isinstance(dim, int) for dim in shape):
return cast(tuple[int, ...], shape)
Expand All @@ -79,6 +82,7 @@ def _mark_tensor_type(
type: astx.TensorType
"""
setattr(data_type, TENSOR_SURFACE_ATTR, True)
data_type.shape = shape
if shape is not None:
setattr(data_type, TENSOR_SHAPE_ATTR, shape)
return data_type
Expand Down Expand Up @@ -132,7 +136,10 @@ def tensor_type(
raise ValueError("tensor shapes must include at least one dimension")
if any(dim < 0 for dim in shape):
raise ValueError("tensor dimensions must be non-negative")
return _mark_tensor_type(astx.TensorType(element_type), shape)
return _mark_tensor_type(
astx.TensorType(element_type, shape=shape),
shape,
)


def runtime_tensor_type(element_type: astx.DataType) -> astx.TensorType:
Expand Down
100 changes: 85 additions & 15 deletions packages/astx/src/astx/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@
from astx.types import AnyType


@typechecked
def _format_type_name(type_: astx.DataType) -> str:
    """
    title: Render a compact, stable name for a DataFrame column type.
    parameters:
      type_:
        type: astx.DataType
    returns:
      type: str
    """
    if not isinstance(type_, astx.PointerType):
        rendered = str(type_)
        class_name = type_.__class__.__name__
        # When str() only yields the "<ClassName>: <name>" form, collapse it
        # to the bare class name; any other rendering is kept verbatim.
        # NOTE(review): presumably that form is the base-class default
        # __str__ -- confirm against astx.DataType.
        if rendered == f"{class_name}: {type_.name}":
            return class_name
        return rendered

    # Pointer types are rendered recursively through their pointee.
    pointee = type_.pointee_type
    if pointee is None:
        return "PointerType"
    return f"PointerType[{_format_type_name(pointee)}]"


@typechecked
@dataclass(frozen=True)
class DataFrameColumn:
Expand Down Expand Up @@ -93,16 +115,20 @@ class SeriesType(AnyType):
type: astx.DataType | None
nullable:
type: bool
size:
type: int | None
"""

element_type: astx.DataType | None
nullable: bool
size: int | None

def __init__(
self,
element_type: astx.DataType | None = None,
*,
nullable: bool = False,
size: int | None = None,
) -> None:
"""
title: Initialize one Series type.
Expand All @@ -111,10 +137,15 @@ def __init__(
type: astx.DataType | None
nullable:
type: bool
size:
type: int | None
"""
super().__init__()
if size is not None and size < 0:
raise ValueError("series size must be non-negative")
self.element_type = element_type
self.nullable = nullable
self.size = size

def __str__(self) -> str:
"""
Expand All @@ -124,7 +155,10 @@ def __str__(self) -> str:
"""
if self.element_type is None:
return "SeriesType"
return f"SeriesType[{self.element_type}]"
parts = [_format_type_name(self.element_type)]
if self.size is not None:
parts.append(str(self.size))
return f"SeriesType[{', '.join(parts)}]"


@typechecked
Expand All @@ -137,22 +171,32 @@ class DataFrameType(AnyType):
attributes:
columns:
type: tuple[DataFrameColumn, Ellipsis] | None
row_count:
type: int | None
"""

columns: tuple[DataFrameColumn, ...] | None
row_count: int | None

def __init__(
self,
columns: Sequence[DataFrameColumn] | None = None,
*,
row_count: int | None = None,
) -> None:
"""
title: Initialize one DataFrame type.
parameters:
columns:
type: Sequence[DataFrameColumn] | None
row_count:
type: int | None
"""
super().__init__()
if row_count is not None and row_count < 0:
raise ValueError("dataframe row count must be non-negative")
self.columns = None if columns is None else tuple(columns)
self.row_count = row_count

def __str__(self) -> str:
"""
Expand All @@ -161,11 +205,38 @@ def __str__(self) -> str:
type: str
"""
if self.columns is None:
if self.row_count is not None:
return f"DataFrameType[{self.row_count}]"
return "DataFrameType"
columns = ", ".join(
f"{column.name}: {column.type_}" for column in self.columns
)
return f"DataFrameType[{columns}]"
parts = [
f"{column.name}: {_format_type_name(column.type_)}"
for column in self.columns
]
if self.row_count is not None:
parts.append(str(self.row_count))
return f"DataFrameType[{', '.join(parts)}]"


@typechecked
def _infer_literal_row_count(
    columns: tuple[DataFrameLiteralColumn, ...],
) -> int | None:
    """
    title: Derive the row count shared by every DataFrame literal column.
    parameters:
      columns:
        type: tuple[DataFrameLiteralColumn, Ellipsis]
    returns:
      type: int | None
    """
    # Lazy generator: lengths are only measured as far as the scan proceeds,
    # so the walk stops at the first column that disagrees.
    lengths = (len(column.values) for column in columns)

    expected = next(lengths, None)
    if expected is None:
        # No columns at all: the literal has zero rows.
        return 0

    # Any disagreement means the row count cannot be stated statically.
    if any(length != expected for length in lengths):
        return None
    return expected


@typechecked
Expand Down Expand Up @@ -199,7 +270,8 @@ def __init__(
"""
super().__init__()
self.columns = tuple(columns)
self.type_ = type_ or DataFrameType()
row_count = _infer_literal_row_count(self.columns)
self.type_ = type_ or DataFrameType(row_count=row_count)

def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct:
"""
Expand All @@ -210,19 +282,17 @@ def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct:
returns:
type: astx.base.ReprStruct
"""
value = {
value: dict[str, object] = {
"columns": [
column.get_struct(simplified) for column in self.columns
],
"type": (
None
if self.type_.columns is None
else [
column.get_struct(simplified)
for column in self.type_.columns
]
),
}
if self.type_.columns is not None:
value["type"] = [
column.get_struct(simplified) for column in self.type_.columns
]
if self.type_.row_count is not None:
value["row_count"] = self.type_.row_count
return self._prepare_struct(
"DataFrameLiteral",
cast(astx.base.ReprStruct, value),
Expand Down
2 changes: 1 addition & 1 deletion packages/astx/src/astx/literals/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
super().__init__(loc)
self.elements = list(elements) # Ensure correct type
unique_types = {type(elem.type_) for elem in elements}
self.type_ = ListType([t() for t in unique_types])
self.type_ = ListType([t() for t in unique_types], size=len(elements))
self.loc = loc


Expand Down
46 changes: 42 additions & 4 deletions packages/astx/src/astx/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@
from astx.types import AnyType


@typechecked
def _format_type_name(type_: astx.DataType) -> str:
    """
    title: Render a compact, stable name for a tensor element type.
    parameters:
      type_:
        type: astx.DataType
    returns:
      type: str
    """
    if isinstance(type_, astx.PointerType):
        # Pointer types are rendered recursively through their pointee.
        target = type_.pointee_type
        return (
            "PointerType"
            if target is None
            else f"PointerType[{_format_type_name(target)}]"
        )

    text = str(type_)
    kind = type_.__class__.__name__
    # When str() only yields the "<ClassName>: <name>" form, collapse it to
    # the bare class name; any other rendering is kept verbatim.
    # NOTE(review): presumably that form is the base-class default __str__ --
    # confirm against astx.DataType.
    return kind if text == f"{kind}: {type_.name}" else text


@typechecked
class TensorType(AnyType):
"""
Expand All @@ -26,19 +48,32 @@ class TensorType(AnyType):
attributes:
element_type:
type: astx.DataType | None
shape:
type: tuple[int, Ellipsis] | None
"""

element_type: astx.DataType | None
shape: tuple[int, ...] | None

def __init__(self, element_type: astx.DataType | None = None) -> None:
def __init__(
self,
element_type: astx.DataType | None = None,
*,
shape: Sequence[int] | None = None,
) -> None:
"""
title: Initialize one Tensor type.
parameters:
element_type:
type: astx.DataType | None
shape:
type: Sequence[int] | None
"""
super().__init__()
self.element_type = element_type
self.shape = None if shape is None else tuple(shape)
if self.shape is not None and any(dim < 0 for dim in self.shape):
raise ValueError("tensor shape dimensions must be non-negative")

def __str__(self) -> str:
"""
Expand All @@ -48,7 +83,10 @@ def __str__(self) -> str:
"""
if self.element_type is None:
return "TensorType"
return f"TensorType[{self.element_type}]"
parts = [_format_type_name(self.element_type)]
if self.shape is not None:
parts.extend(str(dimension) for dimension in self.shape)
return f"TensorType[{', '.join(parts)}]"


@typechecked
Expand Down Expand Up @@ -109,7 +147,7 @@ def __init__(
self.shape = tuple(shape)
self.strides = None if strides is None else tuple(strides)
self.offset_bytes = offset_bytes
self.type_ = TensorType(element_type)
self.type_ = TensorType(element_type, shape=self.shape)

def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct:
"""
Expand Down Expand Up @@ -185,7 +223,7 @@ def __init__(
self.shape = tuple(shape)
self.strides = None if strides is None else tuple(strides)
self.offset_bytes = offset_bytes
self.type_ = TensorType()
self.type_ = TensorType(shape=self.shape)

def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct:
"""
Expand Down
Loading
Loading