diff --git a/docs/library/built-in-types.md b/docs/library/built-in-types.md index bc820fb..4fbbbcb 100644 --- a/docs/library/built-in-types.md +++ b/docs/library/built-in-types.md @@ -174,7 +174,7 @@ fn dataframe_demo() -> i32: return cast(rows.nrows(), i32) ``` -## Casting +## Type-aware builtins Use the built-in `cast(value, type)` helper to convert values between supported types. @@ -184,6 +184,25 @@ fn cast_demo(a: i32) -> str: return cast(a, str) ``` +Use `isinstance(value, type)` to compare the static semantic type of a value +with a concrete type, type alias, or finite union type. + +```arx +type Number = i32 | i64 + +fn check(value: i32) -> bool: + return isinstance(value, Number) +``` + +Use `type(value)` to produce the value's semantic type name as `str`. + +```arx +type Count = i32 + +fn type_name(value: Count) -> str: + return type(value) +``` + ## See Also - [Data Types](datatypes.md) for annotation rules and placement diff --git a/docs/library/datatypes.md b/docs/library/datatypes.md index ae3c7ff..3a3c7c0 100644 --- a/docs/library/datatypes.md +++ b/docs/library/datatypes.md @@ -41,8 +41,10 @@ fn summarize(name: str, values: list[i32]) -> none: Common places where types appear: - function parameters: `fn add(a: i32, b: i32) -> i32:` +- union function parameters: `fn id(value: i32 | i64) -> i32 | i64:` - function return types: `fn test_add() -> none:` - variable declarations: `var total: i32 = 0` +- type aliases: `type Number = i32 | i64` - generic collection annotations: `list[i32]` - shaped 1D tensor annotations: `tensor[i32, 4]` - multidimensional tensor annotations: `tensor[i32, 2, 2]` @@ -61,6 +63,31 @@ function and extern parameter annotations. Static-schema DataFrames can be constructed with `dataframe({...})`, and their columns can be accessed with either `rows.score` or `rows["score"]`. +## Type Aliases And Union Types + +Type aliases are top-level declarations. The `type` word is contextual, so +`type Name = ...` declares an alias while `type(value)` calls the builtin +type-name helper. + +````arx +``` +title: Type alias example +summary: Demonstrates a finite numeric union alias. +``` +type Number = i32 | i64 + +fn identity(value: Number) -> Number: + ``` + title: identity + summary: Returns a numeric union value. + ``` + return value +```` + +Union annotations use `|`. Numeric unions currently lower through a shared +numeric storage type. Runtime tagged unions and runtime narrowing are not part +of the current model. + ## Built-in Type Reference For the catalog of built-in types, aliases, and examples, see @@ -73,3 +100,4 @@ That page covers: - string, character, and temporal types - lists, tensors, dataframes, series, and current limitations - the `cast(value, type)` helper +- the `isinstance(value, type)` and `type(value)` helpers diff --git a/docs/syntax.md b/docs/syntax.md index 59bafee..0f8aad7 100644 --- a/docs/syntax.md +++ b/docs/syntax.md @@ -128,6 +128,7 @@ Contextual keywords: - `as` - `from` +- `type` ### Edge Cases diff --git a/examples/type-aliases.x b/examples/type-aliases.x new file mode 100644 index 0000000..284463d --- /dev/null +++ b/examples/type-aliases.x @@ -0,0 +1,26 @@ +``` +title: Type aliases and type-aware builtins +summary: Demonstrates union aliases, isinstance, type, and cast. +``` + +type Number = i32 | i64 + +fn identity(value: Number) -> Number: + ``` + title: identity + summary: Returns a numeric union value. + ``` + return value + +fn main() -> i32: + ``` + title: main + summary: Calls type-aware builtins with a union alias. + ``` + var value: i64 = identity(5) + var ok: bool = isinstance(value, Number) + var name: str = type(value) + if ok: + return cast(value, i32) + else: + return 1 diff --git a/packages/arx/src/arx/builtins.py b/packages/arx/src/arx/builtins.py index b7f49c2..3bce66e 100644 --- a/packages/arx/src/arx/builtins.py +++ b/packages/arx/src/arx/builtins.py @@ -20,8 +20,10 @@ BUILTIN_CAST = "cast" BUILTIN_DATAFRAME = "dataframe" +BUILTIN_ISINSTANCE = "isinstance" BUILTIN_PRINT = "print" BUILTIN_RANGE = "range" +BUILTIN_TYPE = "type" _GENERATORS_MODULE = f"{BUILTIN_NAMESPACE}.generators" @@ -71,14 +73,18 @@ class AmbientBuiltinBinding: __all__ = [ "BUILTIN_CAST", "BUILTIN_DATAFRAME", + "BUILTIN_ISINSTANCE", "BUILTIN_NAMESPACE", "BUILTIN_PRINT", "BUILTIN_RANGE", "BUILTIN_SOURCE_EXTENSION", + "BUILTIN_TYPE", "AmbientBuiltinBinding", "BuiltinModuleAsset", "build_cast", + "build_isinstance", "build_print", + "build_type_of", "get_ambient_builtin_imports", "get_builtin_source", "is_builtin", @@ -98,7 +104,13 @@ def is_builtin(name: str) -> bool: returns: type: bool """ - return name in {BUILTIN_CAST, BUILTIN_DATAFRAME, BUILTIN_PRINT} + return name in { + BUILTIN_CAST, + BUILTIN_DATAFRAME, + BUILTIN_ISINSTANCE, + BUILTIN_PRINT, + BUILTIN_TYPE, + } def build_cast( @@ -117,6 +129,23 @@ def build_cast( return irx_astx.Cast(value=value, target_type=target_type) +def build_isinstance( + value: astx.Expr, + target_type: astx.DataType, +) -> irx_astx.IsInstanceExpr: + """ + title: Build an IRx IsInstanceExpr node. + parameters: + value: + type: astx.Expr + target_type: + type: astx.DataType + returns: + type: irx_astx.IsInstanceExpr + """ + return irx_astx.IsInstanceExpr(value=value, target_type=target_type) + + def build_print(message: astx.Expr) -> irx_astx.PrintExpr: """ title: Build an IRx PrintExpr node. @@ -129,6 +158,18 @@ def build_print(message: astx.Expr) -> irx_astx.PrintExpr: return irx_astx.PrintExpr(message=message) +def build_type_of(value: astx.Expr) -> irx_astx.TypeOfExpr: + """ + title: Build an IRx TypeOfExpr node. + parameters: + value: + type: astx.Expr + returns: + type: irx_astx.TypeOfExpr + """ + return irx_astx.TypeOfExpr(value=value) + + def is_builtin_module_specifier(specifier: str) -> bool: """ title: Return whether one specifier targets the bundled builtins. diff --git a/packages/arx/src/arx/lexer/syntax.json b/packages/arx/src/arx/lexer/syntax.json index 3a78eab..7bb11de 100644 --- a/packages/arx/src/arx/lexer/syntax.json +++ b/packages/arx/src/arx/lexer/syntax.json @@ -22,7 +22,7 @@ "var", "while" ], - "contextual": ["as", "from", "binary", "unary", "operator"] + "contextual": ["as", "from", "type", "binary", "unary", "operator"] }, "literals": { @@ -147,6 +147,11 @@ "requires_from": true, "allows_trailing_comma": true }, + "type_alias": { + "form": "type Name = type", + "supports_union_rhs": true, + "top_level_only": true + }, "relative_imports": { "leading_dot_tokens": true }, diff --git a/packages/arx/src/arx/parser/__init__.py b/packages/arx/src/arx/parser/__init__.py index de51655..18a2ceb 100644 --- a/packages/arx/src/arx/parser/__init__.py +++ b/packages/arx/src/arx/parser/__init__.py @@ -44,6 +44,8 @@ class Parser( type: list[astx.DataType] template_type_scopes: type: list[dict[str, astx.DataType]] + type_aliases: + type: dict[str, astx.DataType] value_scopes: type: list[set[str]] tokens: @@ -56,6 +58,7 @@ class Parser( tensor_scopes: list[dict[str, TensorBinding | None]] return_type_scopes: list[astx.DataType] template_type_scopes: list[dict[str, astx.DataType]] + type_aliases: dict[str, astx.DataType] value_scopes: list[set[str]] tokens: TokenList diff --git a/packages/arx/src/arx/parser/base.py b/packages/arx/src/arx/parser/base.py index a297370..7008735 100644 --- a/packages/arx/src/arx/parser/base.py +++ b/packages/arx/src/arx/parser/base.py @@ -39,6 +39,8 @@ class ParserMixinBase: type: list[astx.DataType] template_type_scopes: type: list[dict[str, astx.DataType]] + type_aliases: + type: dict[str, astx.DataType] value_scopes: type: list[set[str]] tokens: @@ -53,6 +55,7 @@ class ParserMixinBase: dataframe_scopes: list[dict[str, DataFrameBinding | None]] return_type_scopes: list[astx.DataType] template_type_scopes: list[dict[str, astx.DataType]] + type_aliases: dict[str, astx.DataType] value_scopes: list[set[str]] tokens: TokenList @@ -454,6 +457,20 @@ def parse_type( del allow_template_vars, allow_union, type_context raise NotImplementedError + def is_type_alias_decl_start(self) -> bool: + """ + title: Return whether the current token starts a type alias. + returns: + type: bool + """ + raise NotImplementedError + + def parse_type_alias_decl(self) -> None: + """ + title: Parse one top-level type alias declaration. + """ + raise NotImplementedError + def parse_function( self, template_params: tuple[astx.TemplateParam, ...] = (), diff --git a/packages/arx/src/arx/parser/core.py b/packages/arx/src/arx/parser/core.py index 0fb461f..938a21b 100644 --- a/packages/arx/src/arx/parser/core.py +++ b/packages/arx/src/arx/parser/core.py @@ -41,6 +41,8 @@ class ParserCore(ParserMixinBase): type: list[astx.DataType] template_type_scopes: type: list[dict[str, astx.DataType]] + type_aliases: + type: dict[str, astx.DataType] value_scopes: type: list[set[str]] tokens: @@ -55,6 +57,7 @@ class ParserCore(ParserMixinBase): dataframe_scopes: list[dict[str, DataFrameBinding | None]] return_type_scopes: list[astx.DataType] template_type_scopes: list[dict[str, astx.DataType]] + type_aliases: dict[str, astx.DataType] value_scopes: list[set[str]] tokens: TokenList @@ -89,6 +92,7 @@ def __init__(self, tokens: TokenList = TokenList([])) -> None: self.dataframe_scopes = [{}] self.return_type_scopes = [] self.template_type_scopes = [] + self.type_aliases = {} self.value_scopes = [set()] self.tokens = tokens @@ -103,6 +107,7 @@ def clean(self) -> None: self.dataframe_scopes = [{}] self.return_type_scopes = [] self.template_type_scopes = [] + self.type_aliases = {} self.value_scopes = [set()] self.tokens = TokenList([]) @@ -197,6 +202,11 @@ def parse( allow_module_docstring = False continue + if self.is_type_alias_decl_start(): + self.parse_type_alias_decl() + allow_module_docstring = False + continue + if self.tokens.cur_tok.kind == TokenKind.kw_function: tree.nodes.append(self.parse_function()) allow_module_docstring = False diff --git a/packages/arx/src/arx/parser/declarations.py b/packages/arx/src/arx/parser/declarations.py index edffa14..801bfbc 100644 --- a/packages/arx/src/arx/parser/declarations.py +++ b/packages/arx/src/arx/parser/declarations.py @@ -574,7 +574,7 @@ def parse_method_signature( self._consume_operator(":") param_type = self.parse_type( - type_context=TypeUseContext.PARAMETER + allow_union=True, type_context=TypeUseContext.PARAMETER ) self._append_argument( args, @@ -599,7 +599,10 @@ def parse_method_signature( ) self._consume_operator("->") - return_type = self.parse_type(type_context=TypeUseContext.RETURN) + return_type = self.parse_type( + allow_union=True, + type_context=TypeUseContext.RETURN, + ) return ( astx.FunctionPrototype( method_name, @@ -1079,7 +1082,7 @@ def parse_prototype(self, expect_colon: bool) -> astx.FunctionPrototype: self._consume_operator(":") arg_type = self.parse_type( - type_context=TypeUseContext.PARAMETER + allow_union=True, type_context=TypeUseContext.PARAMETER ) self._append_argument(args, arg_name, arg_type, arg_loc) @@ -1098,7 +1101,7 @@ def parse_prototype(self, expect_colon: bool) -> astx.FunctionPrototype: ) self._consume_operator("->") ret_type: astx.DataType = self.parse_type( - type_context=TypeUseContext.RETURN + allow_union=True, type_context=TypeUseContext.RETURN ) if expect_colon: diff --git a/packages/arx/src/arx/parser/expressions.py b/packages/arx/src/arx/parser/expressions.py index c2fe067..c57cd7c 100644 --- a/packages/arx/src/arx/parser/expressions.py +++ b/packages/arx/src/arx/parser/expressions.py @@ -351,13 +351,34 @@ def parse_identifier_expr(self) -> astx.AST: value_expr = self.parse_expression() self._consume_operator(",") target_type = self.parse_type( - type_context=TypeUseContext.EXPRESSION + allow_union=True, type_context=TypeUseContext.EXPRESSION ) + if isinstance(target_type, astx.UnionType): + raise ParserException( + "Builtin 'cast' does not support union target types yet." + ) self._consume_operator(")") return builtins.build_cast( cast(astx.DataType, value_expr), target_type ) + if id_name == builtins.BUILTIN_ISINSTANCE: + if template_args is not None: + raise ParserException( + f"Builtin '{id_name}' does not accept template arguments." + ) + value_expr = self.parse_expression() + self._consume_operator(",") + target_type = self.parse_type( + allow_union=True, + type_context=TypeUseContext.EXPRESSION, + ) + self._consume_operator(")") + return builtins.build_isinstance( + cast(astx.Expr, value_expr), + target_type, + ) + if id_name == builtins.BUILTIN_PRINT: if template_args is not None: raise ParserException( @@ -367,6 +388,15 @@ def parse_identifier_expr(self) -> astx.AST: self._consume_operator(")") return builtins.build_print(cast(astx.Expr, message)) + if id_name == builtins.BUILTIN_TYPE: + if template_args is not None: + raise ParserException( + f"Builtin '{id_name}' does not accept template arguments." + ) + value_expr = self.parse_expression() + self._consume_operator(")") + return builtins.build_type_of(cast(astx.Expr, value_expr)) + if id_name in {"datetime", "timestamp"}: if template_args is not None: raise ParserException( diff --git a/packages/arx/src/arx/parser/state.py b/packages/arx/src/arx/parser/state.py index 853f0cd..5dba825 100644 --- a/packages/arx/src/arx/parser/state.py +++ b/packages/arx/src/arx/parser/state.py @@ -32,6 +32,7 @@ class TypeUseContext(Enum): INLINE_VARIABLE = "inline variable" FIELD = "field" EXPRESSION = "expression" + TYPE_ALIAS = "type alias" TEMPLATE_BOUND = "template bound" TEMPLATE_ARGUMENT = "template argument" NESTED = "nested type" diff --git a/packages/arx/src/arx/parser/types.py b/packages/arx/src/arx/parser/types.py index 3fef101..bdace3d 100644 --- a/packages/arx/src/arx/parser/types.py +++ b/packages/arx/src/arx/parser/types.py @@ -6,10 +6,13 @@ from __future__ import annotations +import copy + from typing import cast import astx +from arx import builtins from arx.dataframe import ( dataframe_type, is_dataframe_type, @@ -28,12 +31,118 @@ tensor_type, ) +_BUILTIN_TYPE_MAP: dict[str, astx.DataType] = { + "i8": astx.Int8(), + "i16": astx.Int16(), + "i32": astx.Int32(), + "i64": astx.Int64(), + "int8": astx.Int8(), + "int16": astx.Int16(), + "int32": astx.Int32(), + "int64": astx.Int64(), + "f16": astx.Float16(), + "f32": astx.Float32(), + "f64": astx.Float64(), + "float16": astx.Float16(), + "float32": astx.Float32(), + "float64": astx.Float64(), + "bool": astx.Boolean(), + "boolean": astx.Boolean(), + "none": astx.NoneType(), + "str": astx.String(), + "string": astx.String(), + "char": astx.Int8(), + "datetime": astx.DateTime(), + "timestamp": astx.Timestamp(), + "date": astx.Date(), + "time": astx.Time(), +} + +_BUILTIN_TYPE_NAMES = frozenset(_BUILTIN_TYPE_MAP) | frozenset( + {"dataframe", "list", "series", "tensor"} +) + class TypeParserMixin(ParserMixinBase): """ title: Type parser mixin. """ + def is_type_alias_decl_start(self) -> bool: + """ + title: Return whether the current token starts a type alias. + returns: + type: bool + """ + return ( + self._is_identifier_value(builtins.BUILTIN_TYPE) + and self._peek_token().kind == TokenKind.identifier + ) + + def _clone_type_for_alias(self, alias_name: str) -> astx.DataType: + """ + title: Return a cloned target type for one alias reference. + parameters: + alias_name: + type: str + returns: + type: astx.DataType + """ + type_ = copy.deepcopy(self.type_aliases[alias_name]) + if isinstance(type_, astx.UnionType): + type_.alias_name = type_.alias_name or alias_name + return type_ + setattr(type_, "alias_name", alias_name) + return type_ + + def _validate_type_alias_name(self, alias_name: str) -> None: + """ + title: Validate one type alias declaration name. + parameters: + alias_name: + type: str + """ + if alias_name in self.type_aliases: + raise ParserException( + f"Parser: Duplicate type alias '{alias_name}'." + ) + if alias_name in _BUILTIN_TYPE_NAMES: + raise ParserException( + f"Parser: Type alias '{alias_name}' shadows a built-in type." + ) + if builtins.is_builtin(alias_name): + raise ParserException( + f"Parser: Type alias '{alias_name}' shadows a built-in." + ) + if alias_name in self.known_class_names: + raise ParserException( + f"Parser: Type alias '{alias_name}' shadows a class." + ) + + def parse_type_alias_decl(self) -> None: + """ + title: Parse one top-level type alias declaration. + """ + self._consume_identifier_value(builtins.BUILTIN_TYPE) + if self.tokens.cur_tok.kind != TokenKind.identifier: + raise ParserException("Parser: Expected type alias name.") + + alias_name = cast(str, self.tokens.cur_tok.value) + self._validate_type_alias_name(alias_name) + self.tokens.get_next_token() # eat alias name + self._consume_operator("=") + + alias_type = self.parse_type( + allow_template_vars=False, + allow_union=True, + type_context=TypeUseContext.TYPE_ALIAS, + ) + if isinstance(alias_type, astx.UnionType): + alias_type.alias_name = alias_name + else: + setattr(alias_type, "alias_name", alias_name) + self.type_aliases[alias_name] = alias_type + def _consume_runtime_shape_marker(self) -> None: """ title: Consume one runtime-shape ellipsis marker. @@ -291,41 +400,16 @@ def parse_type( except ValueError as err: raise ParserException(str(err)) from err else: - type_map: dict[str, astx.DataType] = { - "i8": astx.Int8(), - "i16": astx.Int16(), - "i32": astx.Int32(), - "i64": astx.Int64(), - "int8": astx.Int8(), - "int16": astx.Int16(), - "int32": astx.Int32(), - "int64": astx.Int64(), - "f16": astx.Float16(), - "f32": astx.Float32(), - "f64": astx.Float64(), - "float16": astx.Float16(), - "float32": astx.Float32(), - "float64": astx.Float64(), - "bool": astx.Boolean(), - "boolean": astx.Boolean(), - "none": astx.NoneType(), - "str": astx.String(), - "string": astx.String(), - "char": astx.Int8(), - "datetime": astx.DateTime(), - "timestamp": astx.Timestamp(), - "date": astx.Date(), - "time": astx.Time(), - } - self.tokens.get_next_token() # eat type identifier - if type_name in type_map: - type_ = type_map[type_name] + if type_name in _BUILTIN_TYPE_MAP: + type_ = copy.deepcopy(_BUILTIN_TYPE_MAP[type_name]) elif template_bound is not None: type_ = astx.TemplateTypeVar( type_name, bound=template_bound, ) + elif type_name in self.type_aliases: + type_ = self._clone_type_for_alias(type_name) elif type_name in self.known_class_names: type_ = astx.ClassType(type_name) else: @@ -336,15 +420,21 @@ def parse_type( if not allow_union or not self._is_operator("|"): return type_ - members = [type_] + members: list[astx.DataType] = [] + if isinstance(type_, astx.UnionType): + members.extend(type_.members) + else: + members.append(type_) while self._is_operator("|"): self._consume_operator("|") - members.append( - self.parse_type( - allow_template_vars=allow_template_vars, - allow_union=False, - type_context=type_context, - ) + member_type = self.parse_type( + allow_template_vars=allow_template_vars, + allow_union=False, + type_context=type_context, ) + if isinstance(member_type, astx.UnionType): + members.extend(member_type.members) + else: + members.append(member_type) return astx.UnionType(members) diff --git a/packages/arx/tests/python/test_app_paths.py b/packages/arx/tests/python/test_app_paths.py index e93e9d9..636f1d0 100644 --- a/packages/arx/tests/python/test_app_paths.py +++ b/packages/arx/tests/python/test_app_paths.py @@ -33,13 +33,22 @@ def test_builtins_helpers() -> None: title: Test builtins helper functions. """ assert builtins.is_builtin("cast") + assert builtins.is_builtin("isinstance") assert builtins.is_builtin("print") + assert builtins.is_builtin("type") assert not builtins.is_builtin("custom_fn") cast_node = builtins.build_cast(astx.LiteralInt32(1), astx.Float32()) + isinstance_node = builtins.build_isinstance( + astx.LiteralInt32(1), + astx.Int32(), + ) print_node = builtins.build_print(astx.LiteralString("hello")) + type_node = builtins.build_type_of(astx.LiteralInt32(1)) assert isinstance(cast_node, irx_astx.Cast) + assert isinstance(isinstance_node, irx_astx.IsInstanceExpr) assert isinstance(print_node, irx_astx.PrintExpr) + assert isinstance(type_node, irx_astx.TypeOfExpr) def test_custom_exceptions_prefixes() -> None: diff --git a/packages/arx/tests/python/test_codegen_ast_output.py b/packages/arx/tests/python/test_codegen_ast_output.py index dd91d60..f7263c8 100644 --- a/packages/arx/tests/python/test_codegen_ast_output.py +++ b/packages/arx/tests/python/test_codegen_ast_output.py @@ -78,6 +78,24 @@ return 0 """ ).lstrip(), + dedent( + """ + type Number = i32 | i64 + + fn identity(value: Number) -> Number: + return value + + fn main() -> i32: + var value: i64 = identity(5) + var ok: bool = isinstance(value, Number) + var name: str = type(value) + print(name) + if ok: + return cast(value, i32) + else: + return 1 + """ + ).lstrip(), dedent( """ class BaseCounter: diff --git a/packages/arx/tests/python/test_codegen_file_object.py b/packages/arx/tests/python/test_codegen_file_object.py index 4fefee8..10b03c6 100644 --- a/packages/arx/tests/python/test_codegen_file_object.py +++ b/packages/arx/tests/python/test_codegen_file_object.py @@ -408,6 +408,97 @@ class Math: assert result.stderr == "" +@pytest.mark.skipif(not HAS_CLANG, reason="clang is required for object build") +def test_type_alias_union_and_type_builtins_build_and_run( + tmp_path: Path, +) -> None: + """ + title: Type aliases, unions, and type-aware builtins should build. + parameters: + tmp_path: + type: Path + """ + module_ast = _parse_min_module( + dedent( + """ + type Number = i32 | i64 + type A = i32 | i64 + type B = i32 | i64 + + fn identity(value: Number) -> Number: + return value + + fn to_b(value: A) -> B: + return value + + fn to_explicit(value: B) -> i32 | i64: + return value + + fn main() -> int32: + var value: i64 = to_explicit(to_b(identity(5))) + var ok: bool = isinstance(value, Number) + var name: str = type(value) + if ok: + return cast(value, int32) + else: + return 1 + """ + ).lstrip() + ) + + bin_path = tmp_path / "type_alias_program" + ArxBuilder().build(module_ast, str(bin_path)) + + result = subprocess.run( + [str(bin_path)], + check=False, + capture_output=True, + text=True, + ) + + assert result.returncode == 5 + assert result.stdout == "" + assert result.stderr == "" + + +@pytest.mark.skipif(not HAS_CLANG, reason="clang is required for object build") +def test_isinstance_uses_membership_not_numeric_widening( + tmp_path: Path, +) -> None: + """ + title: isinstance should not treat assignable numeric types as identical. + parameters: + tmp_path: + type: Path + """ + module_ast = _parse_min_module( + dedent( + """ + fn main() -> int32: + var value: i32 = 1 + if isinstance(value, i64): + return 1 + else: + return 0 + """ + ).lstrip() + ) + + bin_path = tmp_path / "isinstance_membership_program" + ArxBuilder().build(module_ast, str(bin_path)) + + result = subprocess.run( + [str(bin_path)], + check=False, + capture_output=True, + text=True, + ) + + assert result.returncode == 0 + assert result.stdout == "" + assert result.stderr == "" + + def test_build_without_link_writes_object_file( monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: diff --git a/packages/arx/tests/python/test_parser.py b/packages/arx/tests/python/test_parser.py index bc020d2..9b37aa5 100644 --- a/packages/arx/tests/python/test_parser.py +++ b/packages/arx/tests/python/test_parser.py @@ -750,6 +750,76 @@ def test_parse_builtin_cast_and_print() -> None: assert isinstance(fn.body.nodes[1], irx_astx.PrintExpr) +def test_parse_type_alias_and_union_signature() -> None: + """ + title: Type aliases should resolve in union signatures. + """ + tree = _parse_module( + "type Number = i32 | i64\n" + "fn widen(x: Number) -> i32 | i64:\n" + " return x\n" + ) + + fn = tree.nodes[0] + assert isinstance(fn, astx.FunctionDef) + arg_type = fn.prototype.args[0].type_ + ret_type = fn.prototype.return_type + assert isinstance(arg_type, astx.UnionType) + assert arg_type.alias_name == "Number" + assert [type(member) for member in arg_type.members] == [ + astx.Int32, + astx.Int64, + ] + assert isinstance(ret_type, astx.UnionType) + + +def test_parse_type_alias_with_cast_isinstance_and_typeof() -> None: + """ + title: Type aliases should work with type-aware builtins. + """ + tree = _parse_module( + "type Int = i32\n" + "fn main() -> i32:\n" + " var x: Int = cast(0.0, Int)\n" + " var ok: bool = isinstance(x, Int)\n" + " var name: str = type(x)\n" + " return x\n" + ) + + fn = tree.nodes[0] + assert isinstance(fn, astx.FunctionDef) + cast_decl = fn.body.nodes[0] + assert isinstance(cast_decl, astx.VariableDeclaration) + assert isinstance(cast_decl.value, irx_astx.Cast) + isinstance_decl = fn.body.nodes[1] + assert isinstance(isinstance_decl, astx.VariableDeclaration) + assert isinstance(isinstance_decl.value, irx_astx.IsInstanceExpr) + type_decl = fn.body.nodes[2] + assert isinstance(type_decl, astx.VariableDeclaration) + assert isinstance(type_decl.value, irx_astx.TypeOfExpr) + + +def test_parse_type_alias_rejects_builtin_shadowing() -> None: + """ + title: Type aliases should not shadow parser-level builtins. + """ + with pytest.raises(ParserException, match="shadows a built-in"): + _parse_module("type cast = i32\n") + + +def test_parse_cast_rejects_union_target() -> None: + """ + title: Cast should reject union target types until runtime unions exist. + """ + with pytest.raises(ParserException, match="union target types"): + _parse_module( + "type Number = i32 | i64\n" + "fn main() -> i32:\n" + " var x: Number = cast(0.0, Number)\n" + " return 0\n" + ) + + def test_parse_block_with_comment_and_blank_lines() -> None: """ title: Test block parsing across comment/blank lines. diff --git a/packages/astx/src/astx/__init__.py b/packages/astx/src/astx/__init__.py index f30dcc8..5aa83aa 100644 --- a/packages/astx/src/astx/__init__.py +++ b/packages/astx/src/astx/__init__.py @@ -251,7 +251,9 @@ from astx.system import ( AssertStmt, Cast, + IsInstanceExpr, PrintExpr, + TypeOfExpr, ) from astx.templates import ( TemplateParam, @@ -464,6 +466,7 @@ def get_version() -> str: "Int32", "Int64", "Integer", + "IsInstanceExpr", "LambdaExpr", "LeBinOp", "ListAppend", @@ -567,6 +570,7 @@ def get_version() -> str: "Timestamp", "TupleType", "TypeCastExpr", + "TypeOfExpr", "UInt8", "UInt16", "UInt32", diff --git a/packages/astx/src/astx/system.py b/packages/astx/src/astx/system.py index 412b31e..2c40bda 100644 --- a/packages/astx/src/astx/system.py +++ b/packages/astx/src/astx/system.py @@ -169,4 +169,96 @@ def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: return self._prepare_struct(key, value, simplified) -__all__ = ["AssertStmt", "Cast", "PrintExpr"] +@typechecked +class IsInstanceExpr(astx.Expr): + """ + title: IsInstanceExpr AST class. + summary: Represent a type-membership check against one target type. + attributes: + value: + type: astx.Expr + target_type: + type: astx.DataType + """ + + value: astx.Expr + target_type: astx.DataType + + def __init__( + self, + value: astx.Expr, + target_type: astx.DataType, + ) -> None: + """ + title: Initialize IsInstanceExpr. + parameters: + value: + type: astx.Expr + target_type: + type: astx.DataType + """ + super().__init__() + self.value = value + self.target_type = target_type + + def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: + """ + title: Return the AST structure of the isinstance expression. + parameters: + simplified: + type: bool + returns: + type: astx.base.ReprStruct + """ + value: astx.base.DictDataTypesStruct = { + "value": self.value.get_struct(simplified), + "target_type": self.target_type.get_struct(simplified), + } + return self._prepare_struct("IsInstanceExpr", value, simplified) + + +@typechecked +class TypeOfExpr(astx.Expr): + """ + title: TypeOfExpr AST class. + summary: Represent an expression that produces a value's type name. + attributes: + value: + type: astx.Expr + """ + + value: astx.Expr + + def __init__(self, value: astx.Expr) -> None: + """ + title: Initialize TypeOfExpr. + parameters: + value: + type: astx.Expr + """ + super().__init__() + self.value = value + + def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: + """ + title: Return the AST structure of the type-of expression. + parameters: + simplified: + type: bool + returns: + type: astx.base.ReprStruct + """ + return self._prepare_struct( + "TypeOfExpr", + self.value.get_struct(simplified), + simplified, + ) + + +__all__ = [ + "AssertStmt", + "Cast", + "IsInstanceExpr", + "PrintExpr", + "TypeOfExpr", +] diff --git a/packages/irx/src/irx/analysis/handlers/_expressions/operators.py b/packages/irx/src/irx/analysis/handlers/_expressions/operators.py index cbdaff8..8ee10a1 100644 --- a/packages/irx/src/irx/analysis/handlers/_expressions/operators.py +++ b/packages/irx/src/irx/analysis/handlers/_expressions/operators.py @@ -32,6 +32,7 @@ is_integer_type, is_numeric_type, is_string_type, + is_type_member, ) from irx.analysis.typing import binary_result_type, unary_result_type from irx.analysis.validation import validate_assignment, validate_cast @@ -246,6 +247,28 @@ def visit(self, node: astx.Cast) -> None: ) self._set_type(node, target_type) + @SemanticAnalyzerCore.visit.dispatch + def visit(self, node: astx.IsInstanceExpr) -> None: + """ + title: Visit IsInstanceExpr nodes. + parameters: + node: + type: astx.IsInstanceExpr + """ + self.visit(node.value) + self._resolve_declared_type(node.target_type, node=node) + if not self._require_value_expression( + node.value, + context="IsInstanceExpr", + ): + self._set_type(node, astx.Boolean()) + setattr(node, "static_result", False) + return + value_type = self._expr_type(node.value) + static_result = is_type_member(node.target_type, value_type) + setattr(node, "static_result", static_result) + self._set_type(node, astx.Boolean()) + @SemanticAnalyzerCore.visit.dispatch def visit(self, node: astx.PrintExpr) -> None: """ @@ -275,3 +298,20 @@ def visit(self, node: astx.PrintExpr) -> None: code=DiagnosticCodes.SEMANTIC_TYPE_MISMATCH, ) self._set_type(node, astx.Int32()) + + @SemanticAnalyzerCore.visit.dispatch + def visit(self, node: astx.TypeOfExpr) -> None: + """ + title: Visit TypeOfExpr nodes. + parameters: + node: + type: astx.TypeOfExpr + """ + self.visit(node.value) + if not self._require_value_expression( + node.value, + context="TypeOfExpr", + ): + self._set_type(node, astx.String()) + return + self._set_type(node, astx.String()) diff --git a/packages/irx/src/irx/analysis/types.py b/packages/irx/src/irx/analysis/types.py index 72f4b93..7bbb3a7 100644 --- a/packages/irx/src/irx/analysis/types.py +++ b/packages/irx/src/irx/analysis/types.py @@ -217,38 +217,61 @@ def clone_type(type_: astx.DataType) -> astx.DataType: returns: type: astx.DataType """ + + def with_alias(cloned: astx.DataType) -> astx.DataType: + """ + title: Preserve alias metadata on a cloned type. + parameters: + cloned: + type: astx.DataType + returns: + type: astx.DataType + """ + alias_name = getattr(type_, "alias_name", None) + if isinstance(alias_name, str): + setattr(cloned, "alias_name", alias_name) + return cloned + if isinstance(type_, astx.UnionType): return astx.UnionType( tuple(clone_type(member) for member in type_.members), alias_name=type_.alias_name, ) if isinstance(type_, astx.TemplateTypeVar): - return astx.TemplateTypeVar( - type_.name, - bound=clone_type(type_.bound), + return with_alias( + astx.TemplateTypeVar( + type_.name, + bound=clone_type(type_.bound), + ) ) if isinstance(type_, astx.GeneratorType): - return astx.GeneratorType(clone_type(type_.yield_type)) + return with_alias(astx.GeneratorType(clone_type(type_.yield_type))) if isinstance(type_, astx.StructType): - return astx.StructType( - type_.name, - resolved_name=type_.resolved_name, - module_key=type_.module_key, - qualified_name=type_.qualified_name, + return with_alias( + astx.StructType( + type_.name, + resolved_name=type_.resolved_name, + module_key=type_.module_key, + qualified_name=type_.qualified_name, + ) ) if isinstance(type_, astx.ClassType): - return astx.ClassType( - type_.name, - resolved_name=type_.resolved_name, - module_key=type_.module_key, - qualified_name=type_.qualified_name, - ancestor_qualified_names=type_.ancestor_qualified_names, + return with_alias( + astx.ClassType( + type_.name, + resolved_name=type_.resolved_name, + module_key=type_.module_key, + qualified_name=type_.qualified_name, + ancestor_qualified_names=type_.ancestor_qualified_names, + ) ) if isinstance(type_, astx.NamespaceType): - return astx.NamespaceType( - type_.namespace_key, - namespace_kind=type_.namespace_kind, - display_name=type_.display_name, + return with_alias( + astx.NamespaceType( + type_.namespace_key, + namespace_kind=type_.namespace_kind, + display_name=type_.display_name, + ) ) if isinstance(type_, astx.PointerType): pointee_type = ( @@ -256,75 +279,85 @@ def clone_type(type_: astx.DataType) -> astx.DataType: if type_.pointee_type is not None else None ) - return astx.PointerType(pointee_type) + return with_alias(astx.PointerType(pointee_type)) if isinstance(type_, astx.ListType): - return astx.ListType( - [ - clone_type(cast(astx.DataType, element_type)) - for element_type in type_.element_types - ], - size=type_.size, + return with_alias( + astx.ListType( + [ + clone_type(cast(astx.DataType, element_type)) + for element_type in type_.element_types + ], + size=type_.size, + ) ) if isinstance(type_, astx.TupleType): - return astx.TupleType( - [ - clone_type(cast(astx.DataType, element_type)) - for element_type in type_.element_types - ] + return with_alias( + astx.TupleType( + [ + clone_type(cast(astx.DataType, element_type)) + for element_type in type_.element_types + ] + ) ) if isinstance(type_, astx.SetType): - return astx.SetType( - clone_type(cast(astx.DataType, type_.element_type)) + return with_alias( + astx.SetType(clone_type(cast(astx.DataType, type_.element_type))) ) if isinstance(type_, astx.DictType): - return astx.DictType( - clone_type(cast(astx.DataType, type_.key_type)), - clone_type(cast(astx.DataType, type_.value_type)), + return with_alias( + astx.DictType( + clone_type(cast(astx.DataType, type_.key_type)), + clone_type(cast(astx.DataType, type_.value_type)), + ) ) if isinstance(type_, astx.BufferOwnerType): - return type_.__class__() + return with_alias(type_.__class__()) if isinstance(type_, astx.OpaqueHandleType): - return astx.OpaqueHandleType(type_.handle_name) + return with_alias(astx.OpaqueHandleType(type_.handle_name)) if isinstance(type_, astx.BufferViewType): element_type = ( clone_type(type_.element_type) if type_.element_type is not None else None ) - return astx.BufferViewType(element_type) + return with_alias(astx.BufferViewType(element_type)) if isinstance(type_, astx.TensorType): element_type = ( clone_type(type_.element_type) if type_.element_type is not None else None ) - return astx.TensorType(element_type, shape=type_.shape) + return with_alias(astx.TensorType(element_type, shape=type_.shape)) if isinstance(type_, astx.SeriesType): element_type = ( clone_type(type_.element_type) if type_.element_type is not None else None ) - return astx.SeriesType( - element_type, - nullable=type_.nullable, - size=type_.size, + return with_alias( + astx.SeriesType( + element_type, + nullable=type_.nullable, + size=type_.size, + ) ) if isinstance(type_, astx.DataFrameType): if type_.columns is None: - return astx.DataFrameType(row_count=type_.row_count) - return astx.DataFrameType( - tuple( - astx.DataFrameColumn( - column.name, - clone_type(column.type_), - nullable=column.nullable, - ) - for column in type_.columns - ), - row_count=type_.row_count, + return with_alias(astx.DataFrameType(row_count=type_.row_count)) + return with_alias( + astx.DataFrameType( + tuple( + astx.DataFrameColumn( + column.name, + clone_type(column.type_), + nullable=column.nullable, + ) + for column in type_.columns + ), + row_count=type_.row_count, + ) ) - return type_.__class__() + return with_alias(type_.__class__()) @public @@ -340,6 +373,9 @@ def display_type_name(type_: astx.DataType | None) -> str: """ if type_ is None: return "" + alias_name = getattr(type_, "alias_name", None) + if isinstance(alias_name, str): + return alias_name if isinstance(type_, astx.UnionType): if type_.alias_name is not None: return type_.alias_name @@ -444,14 +480,22 @@ def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: if lhs is None or rhs is None: return False if isinstance(lhs, astx.UnionType) and isinstance(rhs, astx.UnionType): - if lhs.alias_name != rhs.alias_name: - return False if len(lhs.members) != len(rhs.members): return False - return all( - same_type(left_member, right_member) - for left_member, right_member in zip(lhs.members, rhs.members) - ) + unmatched_members = list(rhs.members) + for left_member in lhs.members: + matched_index = next( + ( + index + for index, right_member in enumerate(unmatched_members) + if same_type(left_member, right_member) + ), + None, + ) + if matched_index is None: + return False + unmatched_members.pop(matched_index) + return True if isinstance(lhs, astx.TemplateTypeVar) and isinstance( rhs, astx.TemplateTypeVar, @@ -700,6 +744,42 @@ def is_none_type(type_: astx.DataType | None) -> bool: return isinstance(type_, astx.NoneType) +@public +@typechecked +def is_type_member( + target: astx.DataType | None, + value: astx.DataType | None, +) -> bool: + """ + title: Return whether a value type is a member of a target type. + parameters: + target: + type: astx.DataType | None + value: + type: astx.DataType | None + returns: + type: bool + """ + if target is None or value is None: + return False + if isinstance(target, astx.UnionType) and isinstance( + value, + astx.UnionType, + ): + return all( + any( + is_type_member(target_member, value_member) + for target_member in target.members + ) + for value_member in value.members + ) + if isinstance(target, astx.UnionType): + return any(is_type_member(member, value) for member in target.members) + if isinstance(value, astx.UnionType): + return all(is_type_member(target, member) for member in value.members) + return same_type(target, value) + + @public @typechecked def bit_width(type_: astx.DataType | None) -> int: @@ -957,6 +1037,17 @@ def is_assignable( return True if not _metadata_assignment_compatible(target, value): return False + if isinstance(target, astx.UnionType) and isinstance( + value, + astx.UnionType, + ): + return all( + any( + is_assignable(target_member, value_member) + for target_member in target.members + ) + for value_member in value.members + ) if same_type(target, value): return True if isinstance(target, astx.UnionType): diff --git a/packages/irx/src/irx/builder/core.py b/packages/irx/src/irx/builder/core.py index dddcac7..40c2cbc 100644 --- a/packages/irx/src/irx/builder/core.py +++ b/packages/irx/src/irx/builder/core.py @@ -1085,6 +1085,11 @@ def _llvm_type_for_ast_type( """ if type_ is None: return None + if isinstance(type_, astx.UnionType): + storage_type = self._union_storage_ast_type(type_) + if storage_type is None: + return None + return self._llvm_type_for_ast_type(storage_type) if isinstance(type_, astx.NamespaceType): return self._llvm.OPAQUE_POINTER_TYPE if isinstance(type_, astx.GeneratorType): @@ -1148,6 +1153,41 @@ def _llvm_type_for_ast_type( type_name = type_.__class__.__name__.lower() return self._llvm.get_data_type(type_name) + def _union_storage_ast_type( + self, + type_: astx.UnionType, + ) -> astx.DataType | None: + """ + title: Return one scalar storage type for a finite union. + parameters: + type_: + type: astx.UnionType + returns: + type: astx.DataType | None + """ + members = tuple(type_.members) + if not members: + return None + + storage_type: astx.DataType | None = members[0] + for member in members[1:]: + numeric_type = common_numeric_type(storage_type, member) + if numeric_type is not None: + storage_type = numeric_type + continue + + storage_llvm_type = self._llvm_type_for_ast_type(storage_type) + member_llvm_type = self._llvm_type_for_ast_type(member) + if ( + storage_llvm_type is not None + and member_llvm_type is not None + and storage_llvm_type == member_llvm_type + ): + continue + return None + + return storage_type + def _resolved_class_receiver_field_address( self, *, @@ -1453,6 +1493,13 @@ def _cast_ast_value( if source_type is None or target_type is None: return value + if isinstance(source_type, astx.UnionType): + source_type = self._union_storage_ast_type(source_type) + if isinstance(target_type, astx.UnionType): + target_type = self._union_storage_ast_type(target_type) + if source_type is None or target_type is None: + return value + target_llvm_type = self._llvm_type_for_ast_type(target_type) if target_llvm_type is None: return value diff --git a/packages/irx/src/irx/builder/lowering/system.py b/packages/irx/src/irx/builder/lowering/system.py index 037496f..d861258 100644 --- a/packages/irx/src/irx/builder/lowering/system.py +++ b/packages/irx/src/irx/builder/lowering/system.py @@ -4,11 +4,17 @@ title: System/runtime visitor mixins for llvmliteir. """ +from typing import Any, cast + import astx from llvmlite import ir -from irx.analysis.types import is_boolean_type, is_unsigned_type +from irx.analysis.types import ( + display_type_name, + is_boolean_type, + is_unsigned_type, +) from irx.builder.core import VisitorCore from irx.builder.protocols import VisitorMixinBase from irx.builder.runtime import safe_pop @@ -78,6 +84,21 @@ def visit(self, node: astx.Cast) -> None: ) self.result_stack.append(result) + @VisitorCore.visit.dispatch + def visit(self, node: astx.IsInstanceExpr) -> None: + """ + title: Visit IsInstanceExpr nodes. + parameters: + node: + type: astx.IsInstanceExpr + """ + self.visit_child(node.value) + _ = safe_pop(self.result_stack) + static_result = cast(bool, getattr(node, "static_result", False)) + self.result_stack.append( + ir.Constant(self._llvm.BOOLEAN_TYPE, int(static_result)) + ) + @VisitorCore.visit.dispatch def visit(self, node: astx.PrintExpr) -> None: """ @@ -126,3 +147,20 @@ def visit(self, node: astx.PrintExpr) -> None: puts_fn = self.require_runtime_symbol("libc", "puts") self._llvm.ir_builder.call(puts_fn, [ptr]) self.result_stack.append(ir.Constant(self._llvm.INT32_TYPE, 0)) + + @VisitorCore.visit.dispatch + def visit(self, node: astx.TypeOfExpr) -> None: + """ + title: Visit TypeOfExpr nodes. + parameters: + node: + type: astx.TypeOfExpr + """ + self.visit_child(node.value) + _ = safe_pop(self.result_stack) + type_name = display_type_name(self._resolved_ast_type(node.value)) + pointer = cast(Any, self)._constant_c_string_pointer( + type_name, + name_hint="type_name", + ) + self.result_stack.append(pointer) diff --git a/packages/irx/tests/analysis/test_sized_type_metadata.py b/packages/irx/tests/analysis/test_sized_type_metadata.py index 8067bc6..42e58d6 100644 --- a/packages/irx/tests/analysis/test_sized_type_metadata.py +++ b/packages/irx/tests/analysis/test_sized_type_metadata.py @@ -10,6 +10,7 @@ clone_type, display_type_name, is_assignable, + is_type_member, requires_shape_check, requires_size_check, same_type, @@ -102,6 +103,63 @@ def test_same_type_allows_unconstrained_tensor_shape_wildcard() -> None: ) +def test_union_same_type_ignores_alias_names() -> None: + """ + title: Union same_type treats aliases as structural names. + """ + left = astx.UnionType( + (astx.Int32(), astx.Int64()), + alias_name="A", + ) + right = astx.UnionType( + (astx.Int64(), astx.Int32()), + alias_name="B", + ) + + assert same_type(left, right) + + +def test_union_assignability_is_structural() -> None: + """ + title: Union assignment accepts structural aliases and safe subsets. + """ + alias_a = astx.UnionType( + (astx.Int32(), astx.Int64()), + alias_name="A", + ) + alias_b = astx.UnionType( + (astx.Int32(), astx.Int64()), + alias_name="B", + ) + explicit = astx.UnionType((astx.Int32(), astx.Int64())) + subset = astx.UnionType((astx.Int32(),)) + + assert is_assignable(alias_b, alias_a) + assert is_assignable(explicit, alias_a) + assert is_assignable(alias_a, explicit) + assert is_assignable(alias_a, subset) + assert not is_assignable(subset, alias_a) + + +def test_is_type_member_uses_exact_structural_membership() -> None: + """ + title: Type membership avoids assignment-only numeric widening. + """ + alias_a = astx.UnionType( + (astx.Int32(), astx.Int64()), + alias_name="A", + ) + alias_b = astx.UnionType( + (astx.Int64(), astx.Int32()), + alias_name="B", + ) + + assert is_type_member(alias_a, astx.Int32()) + assert is_type_member(alias_a, alias_b) + assert not is_type_member(astx.Int64(), astx.Int32()) + assert not is_type_member(astx.Int64(), alias_a) + + def test_tensor_assignability_checks_known_shapes() -> None: """ title: Tensor assignability rejects unchecked shape narrowing.