diff --git a/benches/programs/fibonacci_typed.ndc b/benches/programs/fibonacci_typed.ndc new file mode 100644 index 00000000..8b904a75 --- /dev/null +++ b/benches/programs/fibonacci_typed.ndc @@ -0,0 +1,5 @@ +fn fib(n: Int) -> Int { + if n <= 1 { 1 } else { fib(n - 2) + fib(n - 1) } +} + +fib(26); diff --git a/manual/src/reference/types.md b/manual/src/reference/types.md index 4ce13150..f6135f42 100644 --- a/manual/src/reference/types.md +++ b/manual/src/reference/types.md @@ -1,30 +1,42 @@ # Types -Andy C++ is currently a dynamically typed language, that means that type checks are performed at runtime. Although -you currently can't annotate your variables using type names they do have types at runtime. - -The type system is hierarchical with the root type being `Any`: - -* Any - * [Option](./types/option.md) - * [Boolean](./types/boolean.md) - * [Number](./types/number.md) - * Integer - * Int64 (64bit signed) - * Bigint (unlimited size) - * Float - * Complex - * Rational - * Sequence - * [String](./types/string.md): A mutable list of characters - * [List](./types/list.md): A mutable list - * [Tuple](./types/tuple.md): An immutable list - * [Unit](./types/unit.md) - * [Map](./types/map-and-set.md): A hashmap that associates keys with values - * [Deque](./types/deque.md): A double ended queue - * [MinHeap & MaxHeap](./types/min-max-heap.md): Min/max Heap - * Iterator: A type that can be consumed and produces values (Currently only used for range expressions like `5..100`) - * [Function](./types/function.md) - -> **Note:** `Any` is the base type for all other types. When you declare a function, its arguments default to type `Any`. -> Currently, the `Any` type is implicit and does not appear explicitly in the language. +Andy C++ runs as a dynamically typed language — values carry their types at runtime and most checking happens then. 
You can also attach type annotations to variables, function parameters, and return values, and the analyser will use them to flag obvious mismatches before the program runs. + +The type system is hierarchical with `Any` at the root: + +* `Any` + * [`Option`](./types/option.md) + * [`Bool`](./types/boolean.md) + * [`Number`](./types/number.md) + * `Int` — machine `i64` or arbitrary-precision `BigInt`, picked automatically + * `Float` + * `Complex` + * `Rational` + * `Sequence` + * [`String`](./types/string.md): a mutable list of characters + * [`List`](./types/list.md): a mutable list + * [`Tuple`](./types/tuple.md): an immutable list + * [`()`](./types/unit.md): unit, the empty tuple + * [`Map`](./types/map-and-set.md): a hashmap that associates keys with values + * [`Deque`](./types/deque.md): a double-ended queue + * [`MinHeap` / `MaxHeap`](./types/min-max-heap.md): min/max heap + * `Iterator`: produces values when consumed (currently only from range expressions like `5..100`) + * [`Function`](./types/function.md) + +These are also the names you write in annotations. Generic types take their parameters in angle brackets: + +```ndc +let xs: List<Int> = [1, 2, 3]; +let table: Map<String, Int> = %{"a": 1, "b": 2}; +let maybe: Option<String> = Some("hi"); +let pair: Tuple<Int, String> = (1, "hi"); +let pair2: (Int, String) = (1, "hi"); // tuple shorthand +``` + +Nested generics work too — the parser handles the `>>` ambiguity for you: + +```ndc +let grid: List<List<Int>> = [[1, 2], [3, 4]]; +``` + +> **Note:** `Any` is the base type for every other type, so an `Any`-annotated binding will accept anything. When a parameter or value has no annotation and the analyser can't infer a type, it falls back to `Any`. There is also a `Never` type used internally for things like `break` that don't produce a value — you'll rarely need to write it by hand. 
diff --git a/manual/src/reference/types/function.md b/manual/src/reference/types/function.md index 41d6abe6..5e429033 100644 --- a/manual/src/reference/types/function.md +++ b/manual/src/reference/types/function.md @@ -89,6 +89,33 @@ let x = fn(y) => y, 3; let x = fn(y) => (y, 3); ``` +## Type annotations + +Parameters and return values can carry type annotations, just like `let` bindings: + +```ndc +fn greet(name: String) -> String => "hello " <> name; + +fn add(x: Int, y: Int) -> Int { + x + y +} +``` + +Annotations are optional — leave them off and the parameter is treated as `Any`. Mix and match as you like: + +```ndc +fn first(xs: List) => xs[0]; // params annotated, return inferred +fn count(xs) -> Int => len(xs); // return annotated, params inferred +``` + +If the body produces a value that doesn't fit the declared return type, the analyser flags it: + +```ndc +fn bad() -> Int { "hello" } // ERROR: mismatched types +``` + +A return-type annotation also helps the analyser understand recursive calls — without it, a recursive call resolves against an unknown return type and you can lose precision. + ## Function overloading You can overload functions by declaring multiple `fn` definitions with the same name and different parameter counts. @@ -108,7 +135,7 @@ fn foo(a) { a + 1 } fn foo(a) { a + 2 } // ERROR: redefinition of 'foo' with 1 parameter ``` -> **Note:** The engine can also overload functions by argument type, and the standard library uses that support in a few places. You cannot write those overloads in user code yet because the language does not let you declare argument types in function signatures. +> **Note:** The engine can also dispatch by argument type, and the standard library uses that to register specialised overloads (for example, an `Int`-only fast path for `+`). User code can't declare two overloads with the same name and arity yet, even when the parameter types differ — the resolver only distinguishes overloads by parameter count. 
## Function shadowing diff --git a/manual/src/reference/variables-and-scopes.md b/manual/src/reference/variables-and-scopes.md index 061cbe0a..45753cb6 100644 --- a/manual/src/reference/variables-and-scopes.md +++ b/manual/src/reference/variables-and-scopes.md @@ -58,6 +58,43 @@ pos = ("a", "b"); // type is still Sequence > **Tip:** For the best type inference, initialize variables with a value that matches the intended type. For example, use `let pos = (0, 0);` instead of `let pos = ();` if you intend to store a 2-tuple of numbers. +## Type annotations + +You can pin a variable's type by adding `: Type` after the name. The initialiser still has to fit; the analyser just checks it for you up front. + +```ndc +let count: Int = 0; +let name: String = "world"; +let xs: List<Int> = [1, 2, 3]; +``` + +A subtype is fine — `Int` fits where `Number` is asked for, and so on: + +```ndc +let n: Number = 3; // OK: Int is a Number +let x: Any = "anything"; // OK: everything is Any +``` + +A mismatch is rejected with a `mismatched types` error: + +```ndc +let x: Int = "hello"; // ERROR: mismatched types: found String but expected Int +``` + +Once a binding has an annotation, it stays locked to that type. Reassignment and augmented assignment can't widen it the way they widen an inferred binding: + +```ndc +let x: Int = 5; +x = "test"; // ERROR: mismatched types +x /= 2; // ERROR: division can produce a Rational, which doesn't fit in Int +``` + +If you want a binding that widens freely, just leave the annotation off. Annotations are opt-in. + +The same syntax shows up on function parameters and return types — see the [Function](./types/function.md) page. + +See [Types](./types.md) for the full list of names you can use, including generics like `List`, `Map`, and tuple shorthand `(Int, String)`. + ## Destructuring Destructuring works more like Python than Rust. Commas matter more than the delimiters, so `[]` and `()` both work. 
diff --git a/ndc_analyser/src/analyser.rs b/ndc_analyser/src/analyser.rs index 2dae6a01..397441ab 100644 --- a/ndc_analyser/src/analyser.rs +++ b/ndc_analyser/src/analyser.rs @@ -1,11 +1,14 @@ use std::collections::HashMap; use std::fmt::Debug; -use crate::scope::ScopeTree; -use itertools::Itertools; +use crate::scope::{ScopeTree, TypeBinding}; +use itertools::{Itertools, izip}; use ndc_core::{StaticType, TypeSignature}; use ndc_lexer::Span; -use ndc_parser::{Binding, Expression, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId}; +use ndc_parser::{ + Binding, Expression, ExpressionLocation, ForBody, ForIteration, FunctionParameter, Lvalue, + NodeId, +}; /// Side table holding semantic information keyed by AST node identity. /// Keeps tooling-specific data (like per-expression types) out of the AST. @@ -13,6 +16,9 @@ use ndc_parser::{Binding, Expression, ExpressionLocation, ForBody, ForIteration, pub struct AnalysisResult { /// Maps each expression node to its inferred result type. pub expr_types: HashMap<NodeId, StaticType>, + /// Inferred return types for functions without explicit annotations. + /// Keyed by the FunctionDeclaration's `NodeId`. + pub inferred_return_types: HashMap<NodeId, StaticType>, /// Errors accumulated during analysis. Non-empty when the analyser /// encountered problems but was able to continue with fallback types. pub errors: Vec<AnalysisError>, @@ -98,7 +104,9 @@ impl Analyser { fn analyse_inner( &mut self, ExpressionLocation { - expression, span, .. 
+ expression, + span, + id, }: &mut ExpressionLocation, ) -> Result { match expression { @@ -142,25 +150,41 @@ impl Analyser { Ok(StaticType::Bool) } Expression::Grouping(expr) => self.analyse(expr), - Expression::VariableDeclaration { l_value, value } => { - let typ = self.analyse_or_any(value); - self.resolve_lvalue_declarative(l_value, typ, *span); + Expression::VariableDeclaration { + l_value, + annotated_type, + value, + } => { + let found_type = self.analyse_or_any(value); + + self.resolve_lvalue_declarative( + l_value, + annotated_type.to_owned(), + found_type.clone(), + *span, + ); Ok(StaticType::unit()) } Expression::Assignment { l_value, r_value } => { let old_type = self.resolve_lvalue_or_any(l_value, *span); let new_type = self.analyse_or_any(r_value); - // Widen the binding's type to the LUB so subsequent uses - // see the broader type. if let Lvalue::Identifier { resolved: Some(target), .. } = l_value { let widened = old_type.lub(&new_type); - if widened != old_type { - self.scope_tree.update_binding_type(*target, widened); + if widened != old_type + && let Err(annotated_type) = + self.scope_tree.update_binding_type(*target, widened) + && !new_type.is_subtype(&annotated_type) + { + self.emit(AnalysisError::mismatched_types( + &new_type, + &annotated_type, + *span, + )); } } @@ -190,42 +214,112 @@ impl Analyser { )); } + // Determine the result type of the operation + let result_type = match resolved_operation { + Binding::Resolved(res) => { + if let StaticType::Function { return_type, .. } = + self.scope_tree.get_type(*res) + { + Some(return_type.as_ref().clone()) + } else { + None + } + } + _ => None, + }; + + if let Some(result_type) = result_type { + match l_value { + // Direct variable: widen or reject if annotated + Lvalue::Identifier { + resolved: Some(target), + .. 
+ } => { + let widened = arg_types[0].lub(&result_type); + if widened != arg_types[0] + && let Err(annotated_type) = + self.scope_tree.update_binding_type(*target, widened) + && !result_type.is_subtype(&annotated_type) + { + self.emit(AnalysisError::mismatched_types( + &result_type, + &annotated_type, + *span, + )); + } + } + // Index into a container: widen the container's type + Lvalue::Index { value, .. } => { + if let Expression::Identifier { + resolved: Binding::Resolved(target), + .. + } = &value.expression + { + let container_type = self.scope_tree.get_type(*target).clone(); + if let Some(elem_type) = container_type.index_element_type() { + let widened_elem = elem_type.lub(&result_type); + if widened_elem != elem_type { + let new_container = + container_type.with_element_type(widened_elem); + let _ = self + .scope_tree + .update_binding_type(*target, new_container); + } + } + } + } + _ => {} + } + } + Ok(StaticType::unit()) } Expression::FunctionDeclaration { name, resolved_name, - type_signature, + parameters, body, return_type: return_type_slot, captures, .. } => { + let type_signature = FunctionParameter::from_params(parameters); + // Pre-register the function before analysing its body so recursive calls can // resolve the name. The return type is unknown at this point so we use Any. - let pre_slot = if let Some(name) = name { - let arity = type_signature.types().map(|t| t.len()); - if self.scope_tree.has_function_in_current_scope(name, arity) { - self.emit(AnalysisError::function_redefinition(name, arity, *span)); - // Skip re-registering but still analyse the body below. - None + let pre_slot = + if let Some(name) = name { + let arity = type_signature.types().map(|t| t.len()); + if self.scope_tree.has_function_in_current_scope(name, arity) { + self.emit(AnalysisError::function_redefinition(name, arity, *span)); + // Skip re-registering but still analyse the body below. 
+ None + } else { + let placeholder = StaticType::Function { + parameters: type_signature.types(), + return_type: Box::new( + return_type_slot.clone().unwrap_or(StaticType::Any), + ), + }; + Some(self.scope_tree.create_local_binding( + name.clone(), + TypeBinding::Inferred(placeholder), + )) + } } else { - let placeholder = StaticType::Function { - parameters: type_signature.types(), - return_type: Box::new(StaticType::Any), - }; - Some( - self.scope_tree - .create_local_binding(name.clone(), placeholder), - ) - } - } else { - None - }; + None + }; self.scope_tree.new_function_scope(); self.return_type_stack.push(None); - let param_types = self.resolve_parameters_declarative(type_signature, *span); + let param_types = self.resolve_parameters_declarative(&type_signature, *span); + + // Fill inferred_type on parameter Lvalues for LSP hints. + for (p, typ) in parameters.iter_mut().zip(&param_types) { + if let Lvalue::Identifier { inferred_type, .. } = &mut p.lvalue { + *inferred_type = Some(typ.clone()); + } + } let implicit_return = self.analyse_or_any(body); let explicit_return = self.return_type_stack.pop().unwrap(); @@ -233,23 +327,37 @@ impl Analyser { self.scope_tree.destroy_scope(); // Combine explicit `return` types with the block's implicit return type. - let return_type = match explicit_return { + let inferred_return = match explicit_return { Some(ret) => ret.lub(&implicit_return), None => implicit_return, }; - *return_type_slot = Some(return_type); + + // If there is an annotated return type, validate it; + // otherwise record the inferred type in the side table. 
+ if let Some(annotated) = return_type_slot { + if !inferred_return.is_subtype(annotated) { + self.emit(AnalysisError::mismatched_types( + &inferred_return, + annotated, + *span, + )); + } + } else { + self.result + .inferred_return_types + .insert(*id, inferred_return.clone()); + } + + let effective_return = return_type_slot.clone().unwrap_or(inferred_return); let function_type = StaticType::Function { parameters: Some(param_types.clone()), - return_type: Box::new( - return_type_slot - .clone() - .expect("must have a value at this point"), - ), + return_type: Box::new(effective_return), }; if let Some(slot) = pre_slot { - self.scope_tree + let _ = self + .scope_tree .update_binding_type(slot, function_type.clone()); *resolved_name = Some(slot); } @@ -401,10 +509,22 @@ impl Analyser { } Binding::Resolved(res) => self.scope_tree.get_type(*res).clone(), - Binding::Dynamic(_) => StaticType::Function { - parameters: None, - return_type: Box::new(StaticType::Any), - }, + Binding::Dynamic(candidates) => { + let return_type = candidates + .iter() + .map(|c| self.scope_tree.get_type(*c).clone()) + .filter_map(|t| match t { + StaticType::Function { return_type, .. 
} => Some(*return_type), + _ => None, + }) + .reduce(|a, b| a.lub(&b)) + .unwrap_or(StaticType::Any); + + StaticType::Function { + parameters: None, + return_type: Box::new(return_type), + } + } }; *resolved = binding; @@ -429,13 +549,14 @@ impl Analyser { self.scope_tree.new_iteration_scope(); - self.resolve_lvalue_declarative( - l_value, - sequence_type - .sequence_element_type() - .unwrap_or(StaticType::Any), - span, - ); + let found_type = sequence_type + .sequence_element_type() + .unwrap_or(StaticType::Any); + + // TODO: get this from the AST when the parser adds it + let expected_type = None; + + self.resolve_lvalue_declarative(l_value, expected_type, found_type, span); do_destroy = true; } ForIteration::Guard(expr) => { @@ -585,33 +706,62 @@ impl Analyser { let mut seen_names: Vec<&str> = Vec::new(); for param in parameters { - types.push(StaticType::Any); + let has_annotation = param.type_name != StaticType::Any; + let binding = if has_annotation { + TypeBinding::Annotated(param.type_name.clone()) + } else { + TypeBinding::Inferred(StaticType::Any) + }; + + types.push(param.type_name.clone()); if seen_names.contains(&param.name.as_str()) { self.emit(AnalysisError::parameter_redefined(&param.name, span)); - // Skip duplicate but continue checking remaining params. continue; } seen_names.push(&param.name); self.scope_tree - .create_local_binding(param.name.clone(), StaticType::Any); + .create_local_binding(param.name.clone(), binding); } types } - fn resolve_lvalue_declarative(&mut self, lvalue: &mut Lvalue, typ: StaticType, span: Span) { + fn resolve_lvalue_declarative( + &mut self, + lvalue: &mut Lvalue, + expected_type: Option<StaticType>, + found_type: StaticType, + span: Span, + ) { match lvalue { Lvalue::Identifier { identifier, resolved, inferred_type, - .. 
+ span, } => { + // If there is a type annotation and the given type is not a subtype of the annotated type we emit an error + if let Some(expected_type) = &expected_type + && !found_type.is_subtype(expected_type) + { + self.emit(AnalysisError::mismatched_types( + &found_type, + expected_type, + *span, + )); + } + + let type_binding = match expected_type { + Some(annotated) => TypeBinding::Annotated(annotated), + None => TypeBinding::Inferred(found_type), + }; + *resolved = Some( self.scope_tree - .create_local_binding(identifier.clone(), typ.clone()), + .create_local_binding(identifier.clone(), type_binding.clone()), ); - *inferred_type = Some(typ); + + *inferred_type = Some(type_binding.typ().clone()) } Lvalue::Index { index, value, .. } => { self.analyse_or_any(index); @@ -623,25 +773,45 @@ impl Analyser { // can happen when a variable is declared with one type (e.g. ()) // and later reassigned to a tuple of a different arity — the // analyser doesn't track reassignment types. + let is_annotated = expected_type.is_some(); + let resolved_type = expected_type.unwrap_or(found_type.clone()); + let sub_types: Box> = - if let StaticType::Tuple(elems) = &typ { + if let StaticType::Tuple(elems) = &resolved_type { if elems.len() != seq.len() { - Box::new(std::iter::repeat(&StaticType::Any)) + self.emit(AnalysisError::tuple_arity_mismatch( + seq.len(), + elems.len(), + span, + )); + return; } else { Box::new(elems.iter()) } - } else if let Some(iter) = typ.unpack() { + } else if let Some(iter) = resolved_type.unpack() { iter } else { - self.emit(AnalysisError::unable_to_unpack_type(&typ, span)); + self.emit(AnalysisError::unable_to_unpack_type(&resolved_type, span)); return; }; - for (sub_lvalue, sub_lvalue_type) in seq.iter_mut().zip(sub_types) { + let found_types = found_type + .unpack() + .unwrap_or_else(|| Box::new(std::iter::repeat(&StaticType::Any))); + + for (sub_lvalue, sub_type, found_type) in + izip!(seq.iter_mut(), sub_types, found_types) + { + let sub_expected 
= if is_annotated { + Some(sub_type.clone()) + } else { + None + }; self.resolve_lvalue_declarative( sub_lvalue, - sub_lvalue_type.clone(), - /* todo: figure out how to narrow this span */ span, + sub_expected, + found_type.clone(), + span, ); } } @@ -678,6 +848,22 @@ impl AnalysisError { pub fn span(&self) -> Span { self.span } + fn tuple_arity_mismatch(ident_len: usize, annotation_len: usize, span: Span) -> Self { + Self { + text: format!( + "mismatched tuple arity: found a len={ident_len} identifier and a len={annotation_len} annotation." + ), + span, + } + } + + fn mismatched_types(found: &StaticType, expected: &StaticType, span: Span) -> Self { + Self { + text: format!("mismatched types: found {found} but expected {expected}"), + span, + } + } + fn function_redefinition(name: &str, arity: Option, span: Span) -> Self { let arity_desc = match arity { Some(n) => format!("{n} parameter{}", if n == 1 { "" } else { "s" }), diff --git a/ndc_analyser/src/scope.rs b/ndc_analyser/src/scope.rs index 69107bad..14eecade 100644 --- a/ndc_analyser/src/scope.rs +++ b/ndc_analyser/src/scope.rs @@ -2,13 +2,37 @@ use ndc_core::StaticType; use ndc_parser::{Binding, CaptureSource, ResolvedVar}; use std::fmt::{Debug, Formatter}; +#[derive(Debug, Clone)] +pub(crate) enum TypeBinding { + Inferred(StaticType), + Annotated(StaticType), +} + +impl TypeBinding { + pub fn typ(&self) -> &StaticType { + match self { + Self::Inferred(t) | Self::Annotated(t) => t, + } + } + + pub fn is_annotated(&self) -> bool { + matches!(self, Self::Annotated(_)) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ScopeBinding { + pub name: String, + pub binding: TypeBinding, +} + #[derive(Debug, Clone)] pub(crate) struct Scope { parent_idx: Option, creates_environment: bool, // Only true for function scopes and for-loop iterations base_offset: usize, function_scope_idx: usize, - identifiers: Vec<(String, StaticType)>, + identifiers: Vec, upvalues: Vec<(String, CaptureSource)>, } @@ -56,7 +80,7 @@ impl 
Scope { pub(crate) fn find_slot_by_name(&self, find_ident: &str) -> Option { self.identifiers .iter() - .rposition(|(ident, _)| ident == find_ident) + .rposition(|b| b.name == find_ident) .map(|idx| idx + self.base_offset) } @@ -68,8 +92,8 @@ impl Scope { self.identifiers .iter() .enumerate() - .filter_map(|(slot, (ident, typ))| { - if ident == find_ident && typ.could_be_callable() { + .filter_map(|(slot, b)| { + if b.name == find_ident && b.binding.typ().could_be_callable() { Some(slot + self.base_offset) } else { None @@ -82,13 +106,13 @@ impl Scope { self.identifiers.iter() .enumerate() .rev() - .filter_map(|(slot, (ident, typ))| { - if ident != find_ident { + .filter_map(|(slot, b)| { + if b.name != find_ident { return None; } // If the thing is not a function we're not interested - let StaticType::Function { parameters, .. } = typ else { + let StaticType::Function { parameters, .. } = b.binding.typ() else { return None; }; @@ -107,17 +131,17 @@ impl Scope { fn find_function(&self, find_ident: &str, find_types: &[StaticType]) -> Option { self.identifiers .iter() - .rposition(|(ident, typ)| ident == find_ident && typ.is_fn_and_matches(find_types)) + .rposition(|b| b.name == find_ident && b.binding.typ().is_fn_and_matches(find_types)) .map(|idx| idx + self.base_offset) } /// Check if this scope already contains a function with the given name and arity. fn has_function_with_arity(&self, name: &str, arity: Option) -> bool { - self.identifiers.iter().any(|(ident, typ)| { - if ident != name { + self.identifiers.iter().any(|b| { + if b.name != name { return false; } - match typ { + match b.binding.typ() { StaticType::Function { parameters: Some(params), .. 
@@ -133,9 +157,11 @@ impl Scope { }) } - fn allocate(&mut self, name: String, typ: StaticType) -> usize { - self.identifiers.push((name, typ)); - // Slot is just the length of the list minus one + fn allocate(&mut self, name: String, type_binding: TypeBinding) -> usize { + self.identifiers.push(ScopeBinding { + name, + binding: type_binding, + }); self.base_offset + self.identifiers.len() - 1 } @@ -184,7 +210,13 @@ impl ScopeTree { /// user-level shadowing. pub fn from_global_scope(global_scope_map: Vec<(String, StaticType)>) -> Self { let mut global_scope = Scope::new_function_scope(None, 0); - global_scope.identifiers = global_scope_map; + global_scope.identifiers = global_scope_map + .into_iter() + .map(|(name, typ)| ScopeBinding { + name, + binding: TypeBinding::Inferred(typ), + }) + .collect(); Self { current_scope_idx: 0, @@ -217,7 +249,7 @@ impl ScopeTree { } } } - ResolvedVar::Global { slot } => &self.global_scope.identifiers[slot].1, + ResolvedVar::Global { slot } => self.global_scope.identifiers[slot].binding.typ(), } } @@ -233,7 +265,7 @@ impl ScopeTree { loop { let scope = &self.scopes[scope_idx]; if slot >= scope.base_offset && slot < scope.base_offset + scope.identifiers.len() { - return &scope.identifiers[slot - scope.base_offset].1; + return scope.identifiers[slot - scope.base_offset].binding.typ(); } scope_idx = scope .parent_idx @@ -407,9 +439,13 @@ impl ScopeTree { Binding::None } - pub(crate) fn create_local_binding(&mut self, ident: String, typ: StaticType) -> ResolvedVar { + pub(crate) fn create_local_binding( + &mut self, + ident: String, + binding: TypeBinding, + ) -> ResolvedVar { ResolvedVar::Local { - slot: self.scopes[self.current_scope_idx].allocate(ident, typ), + slot: self.scopes[self.current_scope_idx].allocate(ident, binding), } } @@ -427,15 +463,31 @@ impl ScopeTree { /// Uses `"\x00"` as a sentinel name that can never collide with user identifiers /// since the lexer never produces null bytes. 
pub(crate) fn reserve_anonymous_slot(&mut self) -> usize { - self.scopes[self.current_scope_idx].allocate("\x00".to_string(), StaticType::Any) + self.scopes[self.current_scope_idx] + .allocate("\x00".to_string(), TypeBinding::Inferred(StaticType::Any)) + } + + /// Try to update a binding's type. Returns `Err` with the annotated type + /// if the binding has an explicit type annotation and cannot be widened. + pub(crate) fn update_binding_type( + &mut self, + var: ResolvedVar, + new_type: StaticType, + ) -> Result<(), StaticType> { + let binding = self.get_binding_mut(var); + if binding.is_annotated() { + return Err(binding.typ().clone()); + } + *binding = TypeBinding::Inferred(new_type); + Ok(()) } - pub(crate) fn update_binding_type(&mut self, var: ResolvedVar, new_type: StaticType) { + fn get_binding_mut(&mut self, var: ResolvedVar) -> &mut TypeBinding { match var { ResolvedVar::Local { slot } => { let scope_idx = self.find_scope_owning_slot(self.current_scope_idx, slot); let base = self.scopes[scope_idx].base_offset; - self.scopes[scope_idx].identifiers[slot - base].1 = new_type; + &mut self.scopes[scope_idx].identifiers[slot - base].binding } ResolvedVar::Upvalue { slot } => { let mut scope_idx = self.scopes[self.current_scope_idx].function_scope_idx; @@ -451,8 +503,7 @@ impl ScopeTree { .expect("expected parent scope"); let owning = self.find_scope_owning_slot(parent, local_slot); let base = self.scopes[owning].base_offset; - self.scopes[owning].identifiers[local_slot - base].1 = new_type; - return; + return &mut self.scopes[owning].identifiers[local_slot - base].binding; } CaptureSource::Upvalue(uv_slot) => { scope_idx = self.get_parent_function_scope_idx(scope_idx); @@ -462,7 +513,7 @@ impl ScopeTree { } } ResolvedVar::Global { .. 
} => { - panic!("update_binding_type called with a global binding") + unreachable!("get_binding_mut called with a global binding") } } } @@ -482,7 +533,7 @@ impl ScopeTree { /// Given a local slot found during a scope walk, return the appropriate `ResolvedVar`. /// If `env_scopes` is empty the slot is in the current function scope and can be - /// referenced directly as a `Local`. Otherwise it must be hoisted through intervening + /// referenced directly as a `Local`. Otherwise, it must be hoisted through intervening /// function scopes as an upvalue chain. fn resolve_found_local( &mut self, @@ -584,7 +635,7 @@ mod tests { #[test] fn single_local_in_function_scope() { let mut tree = empty_scope_tree(); - let var = tree.create_local_binding("x".into(), StaticType::Int); + let var = tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(var, ResolvedVar::Local { slot: 0 }); assert_eq!( tree.get_binding_any("x"), @@ -595,9 +646,9 @@ mod tests { #[test] fn multiple_locals_get_ascending_slots() { let mut tree = empty_scope_tree(); - let x = tree.create_local_binding("x".into(), StaticType::Int); - let y = tree.create_local_binding("y".into(), StaticType::Int); - let z = tree.create_local_binding("z".into(), StaticType::Int); + let x = tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); + let y = tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::Int)); + let z = tree.create_local_binding("z".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(x, ResolvedVar::Local { slot: 0 }); assert_eq!(y, ResolvedVar::Local { slot: 1 }); assert_eq!(z, ResolvedVar::Local { slot: 2 }); @@ -606,11 +657,11 @@ mod tests { #[test] fn block_scope_continues_flat_numbering() { let mut tree = empty_scope_tree(); - let x = tree.create_local_binding("x".into(), StaticType::Int); + let x = tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(x, ResolvedVar::Local { slot: 
0 }); tree.new_block_scope(); - let y = tree.create_local_binding("y".into(), StaticType::Int); + let y = tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(y, ResolvedVar::Local { slot: 1 }); assert_eq!( @@ -622,21 +673,21 @@ mod tests { #[test] fn nested_block_scopes_continue_numbering() { let mut tree = empty_scope_tree(); - tree.create_local_binding("a".into(), StaticType::Int); + tree.create_local_binding("a".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_block_scope(); - let b = tree.create_local_binding("b".into(), StaticType::Int); + let b = tree.create_local_binding("b".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(b, ResolvedVar::Local { slot: 1 }); tree.new_block_scope(); - let c = tree.create_local_binding("c".into(), StaticType::Int); + let c = tree.create_local_binding("c".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(c, ResolvedVar::Local { slot: 2 }); } #[test] fn block_scope_does_not_create_upvalue() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_block_scope(); assert_eq!( @@ -648,10 +699,10 @@ mod tests { #[test] fn function_scope_resets_slots_and_captures_as_upvalue() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_function_scope(); - let y = tree.create_local_binding("y".into(), StaticType::Int); + let y = tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(y, ResolvedVar::Local { slot: 0 }); assert_eq!( @@ -663,10 +714,10 @@ mod tests { #[test] fn iteration_scope_continues_numbering_and_is_transparent() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), 
TypeBinding::Inferred(StaticType::Int)); tree.new_iteration_scope(); - let i = tree.create_local_binding("i".into(), StaticType::Int); + let i = tree.create_local_binding("i".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(i, ResolvedVar::Local { slot: 1 }); assert_eq!( @@ -694,21 +745,21 @@ mod tests { #[test] fn slot_reuse_after_scope_destroy() { let mut tree = empty_scope_tree(); - tree.create_local_binding("a".into(), StaticType::Int); + tree.create_local_binding("a".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_block_scope(); - tree.create_local_binding("b".into(), StaticType::Int); + tree.create_local_binding("b".into(), TypeBinding::Inferred(StaticType::Int)); tree.destroy_scope(); - let c = tree.create_local_binding("c".into(), StaticType::Int); + let c = tree.create_local_binding("c".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(c, ResolvedVar::Local { slot: 1 }); } #[test] fn get_type_returns_correct_type() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); - tree.create_local_binding("y".into(), StaticType::String); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); + tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::String)); assert_eq!( tree.get_type(ResolvedVar::Local { slot: 0 }), @@ -727,7 +778,7 @@ mod tests { #[test] fn upvalue_hoisting_across_two_function_scopes() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_function_scope(); // outer tree.new_function_scope(); // inner @@ -754,8 +805,8 @@ mod tests { #[test] fn multiple_upvalues_get_distinct_indices() { let mut tree = empty_scope_tree(); - tree.create_local_binding("a".into(), StaticType::Int); - tree.create_local_binding("b".into(), StaticType::String); + tree.create_local_binding("a".into(), 
TypeBinding::Inferred(StaticType::Int));
+        tree.create_local_binding("b".into(), TypeBinding::Inferred(StaticType::String));

         tree.new_function_scope();

@@ -770,7 +821,7 @@ mod tests {
     #[test]
     fn duplicate_upvalue_resolution_reuses_index() {
         let mut tree = empty_scope_tree();
-        tree.create_local_binding("x".into(), StaticType::Int);
+        tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int));

         tree.new_function_scope();

@@ -789,7 +840,7 @@ mod tests {
     #[test]
     fn get_type_follows_upvalue_chain() {
         let mut tree = empty_scope_tree();
-        tree.create_local_binding("x".into(), StaticType::Int);
+        tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int));

         tree.new_function_scope(); // outer
         tree.new_function_scope(); // inner
@@ -805,7 +856,7 @@ mod tests {
     #[test]
     fn sibling_closure_finds_existing_upvalue() {
         let mut tree = empty_scope_tree();
-        tree.create_local_binding("x".into(), StaticType::Int);
+        tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int));

         tree.new_function_scope(); // middle

diff --git a/ndc_core/src/static_type.rs b/ndc_core/src/static_type.rs
index aff10d29..3e13ed78 100644
--- a/ndc_core/src/static_type.rs
+++ b/ndc_core/src/static_type.rs
@@ -49,6 +49,17 @@ impl TypeSignature {
         }
     }

+    pub fn from_annotated_bindings(bindings: Vec<(String, Option<StaticType>)>) -> Self {
+        Self::Exact(
+            bindings
+                .into_iter()
+                .map(|(name, annotation)| {
+                    Parameter::new(name, annotation.unwrap_or(StaticType::Any))
+                })
+                .collect(),
+        )
+    }
+
     pub fn types(&self) -> Option<Vec<StaticType>> {
         match self {
             Self::Variadic => None,
@@ -108,7 +119,127 @@ pub enum StaticType {
     Deque(Box<StaticType>),
 }

+#[derive(Debug, Clone, Eq, PartialEq)]
+pub struct StaticTypeConstructionError {
+    message: String,
+    help_text: String,
+}
+
+impl StaticTypeConstructionError {
+    fn new<M: Into<String>, H: Into<String>>(message: M, help_text: H) -> Self {
+        Self {
+            message: message.into(),
+            help_text: help_text.into(),
+        }
+    }
+
+    pub fn help_text(&self) -> &str {
+        &self.help_text
+    }
+}
+
+impl fmt::Display for StaticTypeConstructionError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.message.fmt(f)
+    }
+}
+
 impl StaticType {
+    pub fn from_name_and_args(
+        name: &str,
+        args: Vec<Self>,
+    ) -> Result<Self, StaticTypeConstructionError> {
+        match name {
+            "Any" => Self::require_no_args(name, &args).map(|_| Self::Any),
+            "Never" => Self::require_no_args(name, &args).map(|_| Self::Never),
+            "Bool" => Self::require_no_args(name, &args).map(|_| Self::Bool),
+            "Number" => Self::require_no_args(name, &args).map(|_| Self::Number),
+            "Float" => Self::require_no_args(name, &args).map(|_| Self::Float),
+            "Int" => Self::require_no_args(name, &args).map(|_| Self::Int),
+            "Rational" => Self::require_no_args(name, &args).map(|_| Self::Rational),
+            "Complex" => Self::require_no_args(name, &args).map(|_| Self::Complex),
+            "String" => Self::require_no_args(name, &args).map(|_| Self::String),
+            "Option" => {
+                Self::require_exactly_one_arg(name, args).map(|elem| Self::Option(Box::new(elem)))
+            }
+            "Sequence" => {
+                Self::require_exactly_one_arg(name, args).map(|elem| Self::Sequence(Box::new(elem)))
+            }
+            "List" => {
+                Self::require_exactly_one_arg(name, args).map(|elem| Self::List(Box::new(elem)))
+            }
+            "Iterator" => {
+                Self::require_exactly_one_arg(name, args).map(|elem| Self::Iterator(Box::new(elem)))
+            }
+            "MinHeap" => {
+                Self::require_exactly_one_arg(name, args).map(|elem| Self::MinHeap(Box::new(elem)))
+            }
+            "MaxHeap" => {
+                Self::require_exactly_one_arg(name, args).map(|elem| Self::MaxHeap(Box::new(elem)))
+            }
+            "Deque" => {
+                Self::require_exactly_one_arg(name, args).map(|elem| Self::Deque(Box::new(elem)))
+            }
+            "Tuple" => Self::require_at_least_one_arg(name, args).map(Self::Tuple),
+            "Map" => {
+                let [key, value] = Self::require_exactly_n_args::<2>(name, args)?;
+                Ok(Self::Map {
+                    key: Box::new(key),
+                    value: Box::new(value),
+                })
+            }
+            _ => Err(StaticTypeConstructionError::new(
+                format!("unknown type `{name}`"),
+                "Use a valid type name in this annotation.",
+            )),
+        }
+    }
+
+    fn require_no_args(name: &str, args: &[Self]) -> Result<(), StaticTypeConstructionError> {
+        if args.is_empty() {
+            Ok(())
+        } else {
+            Err(StaticTypeConstructionError::new(
+                format!("type `{name}` does not take generic arguments"),
+                format!("Remove the generic arguments from `{name}`."),
+            ))
+        }
+    }
+
+    fn require_exactly_one_arg(
+        name: &str,
+        args: Vec<Self>,
+    ) -> Result<Self, StaticTypeConstructionError> {
+        let [arg] = Self::require_exactly_n_args(name, args)?;
+        Ok(arg)
+    }
+
+    fn require_exactly_n_args<const N: usize>(
+        name: &str,
+        args: Vec<Self>,
+    ) -> Result<[Self; N], StaticTypeConstructionError> {
+        args.try_into().map_err(|_err: Vec<Self>| {
+            StaticTypeConstructionError::new(
+                format!("type `{name}` expects exactly {N} generic arguments"),
+                format!("Use `{name}<...>` with {N} type arguments."),
+            )
+        })
+    }
+
+    fn require_at_least_one_arg(
+        name: &str,
+        args: Vec<Self>,
+    ) -> Result<Vec<Self>, StaticTypeConstructionError> {
+        if args.is_empty() {
+            Err(StaticTypeConstructionError::new(
+                format!("type `{name}` requires generic arguments"),
+                format!("Add generic arguments like `{name}<...>`."),
+            ))
+        } else {
+            Ok(args)
+        }
+    }
+
     /// Checks if `self` is a subtype of `other`.
     ///
     /// A type S is a subtype of T (S <: T) if a value of type S can be safely
@@ -469,6 +600,21 @@ impl StaticType {
         !self.is_subtype(other) && !other.is_subtype(self)
     }

+    /// Returns a new type with the element type replaced. For container types
+    /// like `List<T>`, this returns `List<new_elem>`. Returns an unchanged clone of `self` if the
+    /// type does not have a replaceable element type.
+ pub fn with_element_type(&self, new_elem: Self) -> Self { + match self { + Self::List(_) => Self::List(Box::new(new_elem)), + Self::Sequence(_) => Self::Sequence(Box::new(new_elem)), + Self::Iterator(_) => Self::Iterator(Box::new(new_elem)), + Self::MinHeap(_) => Self::MinHeap(Box::new(new_elem)), + Self::MaxHeap(_) => Self::MaxHeap(Box::new(new_elem)), + Self::Deque(_) => Self::Deque(Box::new(new_elem)), + _ => self.clone(), + } + } + pub fn index_element_type(&self) -> Option { if let Self::Map { value, .. } = self { return Some(value.as_ref().clone()); diff --git a/ndc_lsp/src/features/inlay_hints.rs b/ndc_lsp/src/features/inlay_hints.rs index 9883883e..339e8c94 100644 --- a/ndc_lsp/src/features/inlay_hints.rs +++ b/ndc_lsp/src/features/inlay_hints.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use ndc_core::StaticType; use ndc_interpreter::AnalysisResult; use ndc_lexer::Span; -use ndc_parser::ExpressionLocation; +use ndc_parser::{ExpressionLocation, NodeId}; use tower_lsp::lsp_types::{InlayHint, InlayHintKind, InlayHintLabel}; use crate::util::position_from_offset; @@ -54,35 +54,124 @@ impl AstVisitor for HintCollector<'_> { } } - fn on_declaration(&mut self, identifier: &str, inferred_type: Option<&StaticType>, span: Span) { + fn on_declaration( + &mut self, + identifier: &str, + inferred_type: Option<&StaticType>, + has_annotation: bool, + span: Span, + ) { if let Some(typ) = inferred_type { - self.hints.push(InlayHint { - position: position_from_offset(self.text, span.end()), - label: InlayHintLabel::String(format!(": {typ}")), - kind: Some(InlayHintKind::TYPE), - text_edits: None, - tooltip: None, - padding_left: None, - padding_right: Some(true), - data: None, - }); + if !has_annotation { + self.hints.push(InlayHint { + position: position_from_offset(self.text, span.end()), + label: InlayHintLabel::String(format!(": {typ}")), + kind: Some(InlayHintKind::TYPE), + text_edits: None, + tooltip: None, + padding_left: None, + padding_right: Some(true), + 
data: None, + }); + } self.variable_types .insert(identifier.to_string(), typ.clone()); } } - fn on_function_declaration(&mut self, return_type: Option<&StaticType>, parameters_span: Span) { - if let Some(rt) = return_type { - self.hints.push(InlayHint { - position: position_from_offset(self.text, parameters_span.end()), - label: InlayHintLabel::String(format!(" -> {rt}")), - kind: Some(InlayHintKind::TYPE), - text_edits: None, - tooltip: None, - padding_left: None, - padding_right: None, - data: None, - }); + fn on_function_declaration( + &mut self, + return_type: Option<&StaticType>, + parameters_span: Span, + node_id: NodeId, + ) { + // return_type is Some only when explicitly annotated by the user — skip the hint. + // Inferred return types are stored in the side table. + if return_type.is_none() { + if let Some(rt) = self.analysis_result.inferred_return_types.get(&node_id) { + self.hints.push(InlayHint { + position: position_from_offset(self.text, parameters_span.end()), + label: InlayHintLabel::String(format!(" -> {rt}")), + kind: Some(InlayHintKind::TYPE), + text_edits: None, + tooltip: None, + padding_left: None, + padding_right: None, + data: None, + }); + } } } } + +#[cfg(test)] +mod tests { + use super::*; + use ndc_interpreter::Interpreter; + + fn collect_hints(source: &str) -> AnalysisInfo { + let mut interpreter = Interpreter::capturing(); + interpreter.configure(ndc_stdlib::register); + let (expressions, analysis_result) = interpreter + .analyse_str(source) + .expect("analysis should succeed"); + collect(&expressions, &analysis_result, source) + } + + #[test] + fn inferred_let_binding_gets_type_inlay() { + let info = collect_hints("let value = 1;"); + assert!( + info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == ": Int") + ) + ); + } + + #[test] + fn annotated_let_binding_skips_type_inlay() { + let info = collect_hints("let value: Int = 1;"); + assert!( + !info.hints.iter().any( + |hint| 
matches!(&hint.label, InlayHintLabel::String(label) if label == ": Int") + ) + ); + assert_eq!(info.variable_types.get("value"), Some(&StaticType::Int)); + } + + #[test] + fn annotated_return_type_skips_inlay() { + let info = collect_hints("fn foo(x: Int) -> Int { x + 1 }"); + assert!(!info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label.contains("->")) + )); + } + + #[test] + fn inferred_return_type_gets_inlay() { + let info = collect_hints("fn foo() { 42 }"); + assert!(info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == " -> Int") + )); + } + + #[test] + fn annotated_param_skips_inlay() { + let info = collect_hints("fn foo(x: Int) { x }"); + assert!( + !info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == ": Int") + ) + ); + } + + #[test] + fn unannotated_param_gets_inlay() { + let info = collect_hints("fn foo(x) { x }"); + assert!( + info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == ": Any") + ) + ); + } +} diff --git a/ndc_lsp/src/visitor.rs b/ndc_lsp/src/visitor.rs index f22b934a..417ff525 100644 --- a/ndc_lsp/src/visitor.rs +++ b/ndc_lsp/src/visitor.rs @@ -1,6 +1,6 @@ use ndc_core::StaticType; use ndc_lexer::Span; -use ndc_parser::{Expression, ExpressionLocation, ForBody, ForIteration, Lvalue}; +use ndc_parser::{Expression, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId}; /// Trait for visiting interesting nodes during an AST walk. 
/// @@ -13,6 +13,7 @@ pub trait AstVisitor { &mut self, _identifier: &str, _inferred_type: Option<&StaticType>, + _has_annotation: bool, _span: Span, ) { } @@ -25,6 +26,7 @@ pub trait AstVisitor { &mut self, _return_type: Option<&StaticType>, _parameters_span: Span, + _node_id: NodeId, ) { } } @@ -39,17 +41,25 @@ pub fn walk_ast(visitor: &mut impl AstVisitor, expressions: &[ExpressionLocation fn walk_expression(visitor: &mut impl AstVisitor, expr: &ExpressionLocation) { visitor.on_expression(expr); match &expr.expression { - Expression::VariableDeclaration { l_value, value } => { - walk_lvalue(visitor, l_value); + Expression::VariableDeclaration { + l_value, + annotated_type, + value, + } => { + walk_lvalue(visitor, l_value, annotated_type.is_some()); walk_expression(visitor, value); } Expression::FunctionDeclaration { return_type, + parameters, parameters_span, body, .. } => { - visitor.on_function_declaration(return_type.as_ref(), *parameters_span); + for p in parameters { + walk_lvalue(visitor, &p.lvalue, p.annotation.is_some()); + } + visitor.on_function_declaration(return_type.as_ref(), *parameters_span, expr.id); walk_expression(visitor, body); } Expression::Statement(inner) | Expression::Grouping(inner) => { @@ -82,7 +92,7 @@ fn walk_expression(visitor: &mut impl AstVisitor, expr: &ExpressionLocation) { for iteration in iterations { match iteration { ForIteration::Iteration { l_value, sequence } => { - walk_lvalue(visitor, l_value); + walk_lvalue(visitor, l_value, false); walk_expression(visitor, sequence); } ForIteration::Guard(expr) => walk_expression(visitor, expr), @@ -113,7 +123,7 @@ fn walk_expression(visitor: &mut impl AstVisitor, expr: &ExpressionLocation) { } } -fn walk_lvalue(visitor: &mut impl AstVisitor, lvalue: &Lvalue) { +fn walk_lvalue(visitor: &mut impl AstVisitor, lvalue: &Lvalue, has_annotation: bool) { match lvalue { Lvalue::Identifier { identifier, @@ -121,11 +131,11 @@ fn walk_lvalue(visitor: &mut impl AstVisitor, lvalue: &Lvalue) { 
span, .. } => { - visitor.on_declaration(identifier, inferred_type.as_ref(), *span); + visitor.on_declaration(identifier, inferred_type.as_ref(), has_annotation, *span); } Lvalue::Sequence(lvalues) => { for lv in lvalues { - walk_lvalue(visitor, lv); + walk_lvalue(visitor, lv, has_annotation); } } Lvalue::Index { .. } => {} diff --git a/ndc_parser/src/expression.rs b/ndc_parser/src/expression.rs index f8c0a803..03b4fcae 100644 --- a/ndc_parser/src/expression.rs +++ b/ndc_parser/src/expression.rs @@ -76,6 +76,7 @@ pub enum Expression { Grouping(Box), VariableDeclaration { l_value: Lvalue, + annotated_type: Option, value: Box, }, Assignment { @@ -92,7 +93,7 @@ pub enum Expression { FunctionDeclaration { name: Option, resolved_name: Option, - type_signature: TypeSignature, + parameters: Vec, parameters_span: Span, body: Box, return_type: Option, @@ -170,6 +171,32 @@ pub enum ForBody { }, } +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct FunctionParameter { + pub lvalue: Lvalue, + pub annotation: Option, + pub span: Span, +} + +impl FunctionParameter { + pub fn from_params(params: &[Self]) -> TypeSignature { + TypeSignature::from_annotated_bindings( + params + .iter() + .map(|p| { + let Lvalue::Identifier { identifier, .. 
} = &p.lvalue else { + unreachable!( + "parameter list may only contain identifiers {:?} found.", + p.lvalue + ); + }; + (identifier.clone(), p.annotation.clone()) + }) + .collect(), + ) + } +} + #[derive(Debug, Eq, PartialEq, Clone)] pub enum Lvalue { // Example: `let foo = ...` diff --git a/ndc_parser/src/lib.rs b/ndc_parser/src/lib.rs index a5227e63..cae9582f 100644 --- a/ndc_parser/src/lib.rs +++ b/ndc_parser/src/lib.rs @@ -3,8 +3,8 @@ mod operator; mod parser; pub use expression::{ - Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId, - ResolvedVar, + Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, + FunctionParameter, Lvalue, NodeId, ResolvedVar, }; pub use operator::{BinaryOperator, LogicalOperator, UnaryOperator}; pub use parser::Error; diff --git a/ndc_parser/src/parser.rs b/ndc_parser/src/parser.rs index ef70156c..077b6828 100644 --- a/ndc_parser/src/parser.rs +++ b/ndc_parser/src/parser.rs @@ -1,7 +1,9 @@ use std::fmt::Write; use crate::expression::Expression; -use crate::expression::{Binding, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId}; +use crate::expression::{ + Binding, ExpressionLocation, ForBody, ForIteration, FunctionParameter, Lvalue, NodeId, +}; use crate::operator::{BinaryOperator, LogicalOperator, UnaryOperator}; use ndc_core::{Parameter, StaticType, TypeSignature}; use ndc_lexer::{Span, Token, TokenLocation}; @@ -309,16 +311,7 @@ impl Parser { .require_current_token_matches(&Token::Let) .expect("guaranteed to match by caller"); - let maybe_lvalue = self.tuple_expression(Self::single_expression, false)?; - let lvalue_span = maybe_lvalue.span; - - let Ok(lvalue) = Lvalue::try_from(maybe_lvalue) else { - return Err(Error::with_help( - "Invalid assignment target".to_string(), - lvalue_span, - "Assignment target is not a valid lvalue. Only a few expressions can be assigned a value. 
Check that the left-hand side of the assignment is a valid target.".to_string(), - )); - }; + let (lvalue, annotated_type) = self.named_binding()?; self.require_current_token_matches(&Token::EqualsSign)?; @@ -326,6 +319,7 @@ impl Parser { let end = expression.span; let declaration = Expression::VariableDeclaration { l_value: lvalue, + annotated_type, value: Box::new(expression), }; @@ -431,25 +425,54 @@ impl Parser { } } + fn delimited_comma_separated( + &mut self, + open: &Token, + close: &Token, + parse_item: fn(&mut Self) -> Result, + allow_empty: bool, + ) -> Result<(Vec, Span), Error> { + let open_span = self.require_current_token_matches(open)?.span; + + if let Some(close_token) = self.consume_token_if(std::slice::from_ref(close)) { + if allow_empty { + return Ok((Vec::new(), open_span.merge(close_token.span))); + } + + return Err(Error::with_help( + format!("expected an item before '{close}'"), + close_token.span, + "This delimited list cannot be empty.".to_string(), + )); + } + + let mut items = vec![parse_item(self)?]; + + while self.consume_token_if(&[Token::Comma]).is_some() { + if self.match_token(std::slice::from_ref(close)).is_some() { + break; + } + + items.push(parse_item(self)?); + } + + let close_span = self.require_current_token_matches(close)?.span; + Ok((items, open_span.merge(close_span))) + } + /// Parses a delimited tuple (enclosed in parentheses) that can be empty fn delimited_tuple( &mut self, next: fn(&mut Self) -> Result, ) -> Result { - let start = self.require_current_token_matches(&Token::LeftParentheses)?; - if let Some(end) = self.consume_token_if(&[Token::RightParentheses]) { - Ok(Expression::Tuple { values: vec![] }.to_location(start.span.merge(end.span))) - } else { - let mut tuple_expression = self.tuple_expression(next, true)?; - let right_paren_span = self - .require_current_token_matches(&Token::RightParentheses)? 
- .span; - - // Include the right paretheses in the span - tuple_expression.span = tuple_expression.span.merge(right_paren_span); + let (values, span) = self.delimited_comma_separated( + &Token::LeftParentheses, + &Token::RightParentheses, + next, + true, + )?; - Ok(tuple_expression) - } + Ok(Expression::Tuple { values }.to_location(span)) } fn single_expression(&mut self) -> Result { @@ -1158,7 +1181,22 @@ impl Parser { } }; - let argument_list = self.delimited_tuple(Self::single_expression)?; + // let argument_list = self.delimited_tuple(Self::single_expression)?; + + let (argument_list, parameters_span) = self.delimited_comma_separated( + &Token::LeftParentheses, + &Token::RightParentheses, + Self::named_parameter, + true, + )?; + + // Optional return type annotation: `-> Type` + let annotated_return_type = if self.peek_current_token() == Some(&Token::RightArrow) { + self.advance(); + Some(self.static_type()?) + } else { + None + }; // Next we either expect a body block `{ ... }` or a fat arrow followed by a single expression `=> ...` @@ -1175,20 +1213,17 @@ impl Parser { "Expected that the argument list is followed by either a body `{}` or a fat arrow `=>`".to_string(), )) } - None => return Err(Error::end_of_input(argument_list.span)), + None => return Err(Error::end_of_input(parameters_span)), }; - let parameters_span = argument_list.span; let span = fn_token.span.merge(body.span); Ok(ExpressionLocation { expression: Expression::FunctionDeclaration { name: identifier, - type_signature: argument_list - .try_into() - .expect("INTERNAL ERROR: type of argument list is incorrect"), + parameters: argument_list, parameters_span, body: Box::new(body), - return_type: None, // At some point in the future we could use type declarations here to insert the type (return type inference is cringe anyway) + return_type: annotated_return_type, resolved_name: None, captures: vec![], pure: is_pure, @@ -1296,6 +1331,171 @@ impl Parser { }; Ok(Expression::Map { values, default 
}.to_location(map_open_span.merge(map_close_span)))
     }
+
+    pub fn static_type(&mut self) -> Result<StaticType, Error> {
+        let Some(TokenLocation { token, span }) = self.peek_current_token_location() else {
+            return Err(Error::end_of_input(
+                self.tokens.last().expect("last token exists").span,
+            ));
+        };
+
+        match token {
+            Token::Identifier(_) => self.named_or_generic_type(),
+            Token::LeftParentheses => self.tuple_type(),
+            _ => Err(Error::with_help(
+                format!("expected a type annotation, found `{token}`"),
+                *span,
+                "Use a valid type name or tuple type annotation in this position.".to_string(),
+            )),
+        }
+    }
+
+    pub fn named_or_generic_type(&mut self) -> Result<StaticType, Error> {
+        let Ok(TokenLocation {
+            token: Token::Identifier(ident),
+            span,
+        }) = self.require_current_token()
+        else {
+            unreachable!("this should have been checked");
+        };
+
+        let generic_args = if self.peek_current_token() == Some(&Token::Less) {
+            self.delimited_type_params()?
+        } else {
+            Vec::new()
+        };
+
+        StaticType::from_name_and_args(ident.as_str(), generic_args)
+            .map_err(|err| Error::with_help(err.to_string(), span, err.help_text().to_string()))
+    }
+
+    /// Parses `<...>` type parameter lists, handling the `>>` / `>=` / `>>=`
+    /// ambiguity that arises with nested generics like `List<List<Int>>`.
+    fn delimited_type_params(&mut self) -> Result<Vec<StaticType>, Error> {
+        self.require_current_token_matches(&Token::Less)?;
+
+        let mut items = vec![self.static_type()?];
+
+        while self.consume_token_if(&[Token::Comma]).is_some() {
+            if self.peek_current_token() == Some(&Token::Greater) {
+                break;
+            }
+            items.push(self.static_type()?);
+        }
+
+        self.consume_closing_angle_bracket()?;
+        Ok(items)
+    }
+
+    /// Consumes a closing `>` for a generic type parameter list. If the current
+    /// token is `>>`, `>=`, or `>>=`, it is split so that the leading `>` is
+    /// consumed and the remainder is left as the current token.
+    fn consume_closing_angle_bracket(&mut self) -> Result<Span, Error> {
+        if let Some(token) = self.consume_token_if(&[Token::Greater]) {
+            return Ok(token.span);
+        }
+
+        let Some(loc) = self.peek_current_token_location() else {
+            return Err(Error::end_of_input(
+                self.tokens.last().expect("last token exists").span,
+            ));
+        };
+
+        let greater_span = Span::new(loc.span.source_id(), loc.span.offset(), 1);
+        let rest_span = Span::new(
+            loc.span.source_id(),
+            loc.span.offset() + 1,
+            loc.span.end() - loc.span.offset() - 1,
+        );
+
+        let remainder = match &loc.token {
+            // >> becomes >
+            Token::GreaterGreater => Token::Greater,
+            // >= becomes =
+            Token::GreaterEquals => Token::EqualsSign,
+            // >>= (OpAssign(>>)) becomes >=
+            Token::OpAssign(inner) if inner.token == Token::GreaterGreater => Token::GreaterEquals,
+            _ => {
+                let loc = loc.clone();
+                return Err(Error::text(
+                    format!("Expected token '>' but got '{}' instead", loc.token),
+                    loc.span,
+                ));
+            }
+        };
+
+        self.tokens[self.current] = TokenLocation {
+            token: remainder,
+            span: rest_span,
+        };
+
+        Ok(greater_span)
+    }
+
+    pub fn tuple_type(&mut self) -> Result<StaticType, Error> {
+        let (types, _span) = self.delimited_comma_separated(
+            &Token::LeftParentheses,
+            &Token::RightParentheses,
+            Self::static_type,
+            true,
+        )?;
+        Ok(StaticType::Tuple(types))
+    }
+
+    fn named_parameter(&mut self) -> Result<FunctionParameter, Error> {
+        let maybe_lvalue = self.single_expression()?;
+        let lvalue_span = maybe_lvalue.span;
+
+        let Ok(lvalue) = Lvalue::try_from(maybe_lvalue) else {
+            return Err(Error::with_help(
+                "Expected parameter name".to_string(),
+                lvalue_span,
+                "Function parameters must be identifiers, optionally followed by a type annotation (e.g. `x` or `x: Int`).".to_string(),
+            ));
+        };
+
+        let annotation = if self.peek_current_token() == Some(&Token::Colon) {
+            self.advance();
+            Some(self.static_type()?)
+ } else { + None + }; + + let span = if annotation.is_some() { + lvalue_span.merge(self.tokens[self.current - 1].span) + } else { + lvalue_span + }; + + Ok(FunctionParameter { + lvalue, + annotation, + span, + }) + } + + pub fn named_binding(&mut self) -> Result<(Lvalue, Option), Error> { + let maybe_lvalue = self.tuple_expression(Self::single_expression, false)?; + let lvalue_span = maybe_lvalue.span; + + let Ok(lvalue) = Lvalue::try_from(maybe_lvalue) else { + return Err(Error::with_help( + "Invalid assignment target".to_string(), + lvalue_span, + "Assignment target is not a valid lvalue. Only a few expressions can be assigned a value. Check that the left-hand side of the assignment is a valid target.".to_string(), + )); + }; + + let annotated_type = if self.peek_current_token() == Some(&Token::Colon) { + self.advance(); + Some(self.static_type()?) + } else { + None + }; + + Ok((lvalue, annotated_type)) + } + fn peek_range_end(&self) -> bool { matches!( self.peek_current_token(), @@ -1320,9 +1520,12 @@ pub struct Error { impl Error { #[must_use] - pub fn text(text: String, span: Span) -> Self { + pub fn text(text: S, span: Span) -> Self + where + S: Into, + { Self { - text, + text: text.into(), span, help_text: None, } diff --git a/ndc_stdlib/src/math.rs b/ndc_stdlib/src/math.rs index b2377dc0..5732c824 100644 --- a/ndc_stdlib/src/math.rs +++ b/ndc_stdlib/src/math.rs @@ -172,6 +172,7 @@ mod inner { pub mod f64 { use super::{Number, ToPrimitive, f64}; use ndc_core::StaticType; + use ndc_core::int::Int; use ndc_core::num::BinaryOperatorError; use ndc_vm::error::VmError; use ndc_vm::value::{NativeFunc, NativeFunction, Value}; @@ -241,6 +242,103 @@ pub mod f64 { "Returns the Euclidean remainder of dividing two numbers. The result is always non-negative." ); + // Int-specific overloads: fast path on i64, fall back to Number on overflow/BigInt. + macro_rules! 
implement_binary_operator_on_int { + ($operator:literal, $checked_method:ident, $fallback:expr, $docs:literal) => { + env.declare_global_fn(Rc::new(NativeFunction { + name: $operator.to_string(), + documentation: Some($docs.to_string()), + static_type: StaticType::Function { + parameters: Some(vec![StaticType::Int, StaticType::Int]), + return_type: Box::new(StaticType::Int), + }, + func: NativeFunc::Simple(Box::new(|args| match args { + [Value::Int(l), Value::Int(r)] => { + if let Some(result) = l.$checked_method(*r) { + Ok(Value::Int(result)) + } else { + let l = Int::Int64(*l); + let r = Int::Int64(*r); + Ok(Value::from_int($fallback(l, r))) + } + } + [left, right] => { + let l = left.to_int().ok_or_else(|| { + VmError::native(format!("expected int, got {}", left.static_type())) + })?; + let r = right.to_int().ok_or_else(|| { + VmError::native(format!( + "expected int, got {}", + right.static_type() + )) + })?; + Ok(Value::from_int($fallback(l, r))) + } + _ => Err(VmError::native(format!( + "expected 2 arguments, got {}", + args.len() + ))), + })), + })); + }; + } + + implement_binary_operator_on_int!( + "+", + checked_add, + std::ops::Add::add, + "Adds two integers." + ); + implement_binary_operator_on_int!( + "-", + checked_sub, + std::ops::Sub::sub, + "Subtracts two integers." + ); + implement_binary_operator_on_int!( + "*", + checked_mul, + std::ops::Mul::mul, + "Multiplies two integers." + ); + implement_binary_operator_on_int!( + "%", + checked_rem, + std::ops::Rem::rem, + "Returns the remainder of dividing two integers." + ); + + // Float-specific overloads: operate directly on f64. + macro_rules! 
implement_binary_operator_on_float { + ($operator:literal, $op:expr, $docs:literal) => { + env.declare_global_fn(Rc::new(NativeFunction { + name: $operator.to_string(), + documentation: Some($docs.to_string()), + static_type: StaticType::Function { + parameters: Some(vec![StaticType::Float, StaticType::Float]), + return_type: Box::new(StaticType::Float), + }, + func: NativeFunc::Simple(Box::new(|args| match args { + [Value::Float(l), Value::Float(r)] => Ok(Value::Float($op(*l, *r))), + _ => Err(VmError::native(format!( + "expected 2 float arguments, got {}", + args.len() + ))), + })), + })); + }; + } + + implement_binary_operator_on_float!("+", std::ops::Add::add, "Adds two floats."); + implement_binary_operator_on_float!("-", std::ops::Sub::sub, "Subtracts two floats."); + implement_binary_operator_on_float!("*", std::ops::Mul::mul, "Multiplies two floats."); + implement_binary_operator_on_float!("/", std::ops::Div::div, "Divides two floats."); + implement_binary_operator_on_float!( + "%", + std::ops::Rem::rem, + "Returns the remainder of dividing two floats." + ); + env.declare_global_fn(Rc::new(NativeFunction { name: "-".to_string(), documentation: Some("Negates a number.".to_string()), diff --git a/ndc_stdlib/src/sequence.rs b/ndc_stdlib/src/sequence.rs index 9443e712..5240880b 100644 --- a/ndc_stdlib/src/sequence.rs +++ b/ndc_stdlib/src/sequence.rs @@ -477,6 +477,7 @@ mod inner { } /// Returns `true` if the `predicate` is true for all the elements in `seq`. 
+ #[function(return_type = bool)] pub fn all(seq: SeqValue, function: &mut VmCallable<'_>) -> anyhow::Result { for item in seq .try_into_iter() diff --git a/ndc_vm/src/compiler.rs b/ndc_vm/src/compiler.rs index 1ce86153..dddc0aa5 100644 --- a/ndc_vm/src/compiler.rs +++ b/ndc_vm/src/compiler.rs @@ -4,8 +4,8 @@ use crate::{Object, Value}; use ndc_core::{StaticType, TypeSignature}; use ndc_lexer::Span; use ndc_parser::{ - Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, LogicalOperator, - Lvalue, ResolvedVar, + Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, + FunctionParameter, LogicalOperator, Lvalue, ResolvedVar, }; use std::rc::Rc; @@ -152,7 +152,7 @@ impl Compiler { } } } - Expression::VariableDeclaration { value, l_value } => { + Expression::VariableDeclaration { value, l_value, .. } => { self.compile_expr(*value)?; self.compile_declare_lvalue(l_value, span)?; } @@ -299,12 +299,13 @@ impl Compiler { name, resolved_name, body, - type_signature, + parameters, return_type, captures, pure, .. } => { + let type_signature = FunctionParameter::from_params(¶meters); self.compile_function_decl( name, resolved_name, diff --git a/tests/programs/004_basic/046_annotated_let_binding.ndc b/tests/programs/004_basic/046_annotated_let_binding.ndc new file mode 100644 index 00000000..6b88f8df --- /dev/null +++ b/tests/programs/004_basic/046_annotated_let_binding.ndc @@ -0,0 +1,24 @@ +// This test asserts that supported annotated let bindings are valid syntax. 
+let any_value: Any = 3; +while false { + let never_value: Never = break; +} +let bool_value: Bool = true; +let int_value: Int = 3; +let float_value: Float = 3.0; +let rational_value: Number = 3 / 4; +let complex_value: Number = 1 + 2i; +let number_value: Number = 3; +let string_value: String = "hello"; + +let option_value: Option = Some(3); +let sequence_value: Sequence = [1, 2, 3]; +let list_value: List = [1, 2, 3]; +let iterator_value: Iterator = 1..10; +let min_heap_value: MinHeap = MinHeap(); +let max_heap_value: MaxHeap = MaxHeap(); +let deque_value: Deque = Deque(); +let map_value: Map = %{"a": 1, "b": 2}; +let tuple_named_value: Tuple = (1, "hello"); +let tuple_shorthand_value: (Int, String) = (1, "hello"); +let tuple_empty_value: () = (); diff --git a/tests/programs/004_basic/047_annotated_let_type_mismatch.ndc b/tests/programs/004_basic/047_annotated_let_type_mismatch.ndc new file mode 100644 index 00000000..5359df52 --- /dev/null +++ b/tests/programs/004_basic/047_annotated_let_type_mismatch.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found String but expected Int +let x: Int = "hello"; diff --git a/tests/programs/004_basic/048_annotated_let_type_mismatch_bool.ndc b/tests/programs/004_basic/048_annotated_let_type_mismatch_bool.ndc new file mode 100644 index 00000000..fd05d342 --- /dev/null +++ b/tests/programs/004_basic/048_annotated_let_type_mismatch_bool.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found Bool but expected String +let x: String = true; diff --git a/tests/programs/004_basic/049_annotated_let_type_mismatch_list.ndc b/tests/programs/004_basic/049_annotated_let_type_mismatch_list.ndc new file mode 100644 index 00000000..7f292235 --- /dev/null +++ b/tests/programs/004_basic/049_annotated_let_type_mismatch_list.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found List but expected List +let x: List = [1, 2, 3]; diff --git a/tests/programs/004_basic/050_annotated_let_type_mismatch_tuple.ndc 
b/tests/programs/004_basic/050_annotated_let_type_mismatch_tuple.ndc new file mode 100644 index 00000000..06c3dd48 --- /dev/null +++ b/tests/programs/004_basic/050_annotated_let_type_mismatch_tuple.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found Tuple but expected Tuple +let x: (String, String) = (1, 2); diff --git a/tests/programs/004_basic/051_annotated_let_subtype_accepted.ndc b/tests/programs/004_basic/051_annotated_let_subtype_accepted.ndc new file mode 100644 index 00000000..7a231d3a --- /dev/null +++ b/tests/programs/004_basic/051_annotated_let_subtype_accepted.ndc @@ -0,0 +1,3 @@ +// expect-output: 42 +let x: Number = 42; +print(x); diff --git a/tests/programs/004_basic/052_annotated_let_rejects_supertype.ndc b/tests/programs/004_basic/052_annotated_let_rejects_supertype.ndc new file mode 100644 index 00000000..b9ee794f --- /dev/null +++ b/tests/programs/004_basic/052_annotated_let_rejects_supertype.ndc @@ -0,0 +1,3 @@ +// expect-error: mismatched types: found Number but expected Int +let x: Number = 3; +let y: Int = x; diff --git a/tests/programs/004_basic/053_nested_generics.ndc b/tests/programs/004_basic/053_nested_generics.ndc new file mode 100644 index 00000000..c47bf566 --- /dev/null +++ b/tests/programs/004_basic/053_nested_generics.ndc @@ -0,0 +1,19 @@ +// expect-output: [[1]] +// expect-output: [2] +// expect-output: [[3]] + +// This test ensures the parser correctly splits compound `>` tokens +// when closing nested generic type parameters. 
+ +// >> is split into > > +let xs: List> = [[1]]; + +// >= is split into > = (no space before `=`) +let ys: List= [2]; + +// >>= is split into > >= then > = (no space before `=` with nested generics) +let zs: List>= [[3]]; + +print(xs); +print(ys); +print(zs); diff --git a/tests/programs/004_basic/054_annotated_let_reassignment_rejected.ndc b/tests/programs/004_basic/054_annotated_let_reassignment_rejected.ndc new file mode 100644 index 00000000..78dda3c0 --- /dev/null +++ b/tests/programs/004_basic/054_annotated_let_reassignment_rejected.ndc @@ -0,0 +1,3 @@ +let x: Int = 5; +x = "test"; +// expect-error: mismatched types diff --git a/tests/programs/004_basic/055_annotated_let_op_assign_rejected.ndc b/tests/programs/004_basic/055_annotated_let_op_assign_rejected.ndc new file mode 100644 index 00000000..88cbc43e --- /dev/null +++ b/tests/programs/004_basic/055_annotated_let_op_assign_rejected.ndc @@ -0,0 +1,3 @@ +let x: Int = 3; +x /= 4; +// expect-error: mismatched types diff --git a/tests/programs/005_functions/037_return_type_annotation.ndc b/tests/programs/005_functions/037_return_type_annotation.ndc new file mode 100644 index 00000000..c6374d61 --- /dev/null +++ b/tests/programs/005_functions/037_return_type_annotation.ndc @@ -0,0 +1,6 @@ +fn greet(name: String) -> String => "hello " <> name; +assert_eq(greet("world"), "hello world"); + +fn identity(x: Int) -> Int => x; +assert_eq(identity(5), 5); +// expect-output: diff --git a/tests/programs/005_functions/038_return_type_annotation_mismatch.ndc b/tests/programs/005_functions/038_return_type_annotation_mismatch.ndc new file mode 100644 index 00000000..90ace995 --- /dev/null +++ b/tests/programs/005_functions/038_return_type_annotation_mismatch.ndc @@ -0,0 +1,2 @@ +fn bad() -> Int { "hello" } +// expect-error: mismatched types diff --git a/tests/programs/900_bugs/bug0017_function_parser_crash.ndc b/tests/programs/900_bugs/bug0017_function_parser_crash.ndc new file mode 100644 index 00000000..4d5ebada --- 
/dev/null +++ b/tests/programs/900_bugs/bug0017_function_parser_crash.ndc @@ -0,0 +1,2 @@ +// expect-error: Expected parameter name +fn x(1 + 1) { } diff --git a/tests/programs/900_bugs/bug0018_invalid_tuple_arity_binding.ndc b/tests/programs/900_bugs/bug0018_invalid_tuple_arity_binding.ndc new file mode 100644 index 00000000..8d252d1b --- /dev/null +++ b/tests/programs/900_bugs/bug0018_invalid_tuple_arity_binding.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched tuple arity: found a len=2 identifier and a len=1 annotation. +let (a, b): (Int) = (1, 2); \ No newline at end of file