diff --git a/rust/src/analyzer/definitions.rs b/rust/src/analyzer/definitions.rs index d97ac83..f7ad244 100644 --- a/rust/src/analyzer/definitions.rs +++ b/rust/src/analyzer/definitions.rs @@ -67,6 +67,8 @@ pub(crate) fn process_def_node( let method_name = String::from_utf8_lossy(def_node.name().as_slice()).to_string(); install_method(genv, method_name.clone()); + let merge_vtx = genv.scope_manager.current_method_return_vertex(); + // Process parameters BEFORE processing body let param_vtxs = if let Some(params_node) = def_node.parameters() { install_parameters(genv, lenv, changes, source, ¶ms_node) @@ -81,15 +83,22 @@ pub(crate) fn process_def_node( } } - // Register user-defined method with return vertex and param vertices (before exiting scope) - if let Some(return_vtx) = last_vtx { + // Connect last expression to merge vertex so that implicit return + // (Ruby's last-expression-is-return-value) is included in the union type + if let (Some(last), Some(merge)) = (last_vtx, merge_vtx) { + genv.add_edge(last, merge); + } + + // Register user-defined method with merge vertex as return vertex + let return_vtx = merge_vtx.or(last_vtx); + if let Some(ret_vtx) = return_vtx { let recv_type_name = genv.scope_manager.current_qualified_name(); if let Some(name) = recv_type_name { genv.register_user_method( Type::instance(&name), &method_name, - return_vtx, + ret_vtx, param_vtxs, ); } diff --git a/rust/src/analyzer/install.rs b/rust/src/analyzer/install.rs index 3841208..20150d4 100644 --- a/rust/src/analyzer/install.rs +++ b/rust/src/analyzer/install.rs @@ -14,6 +14,7 @@ use super::definitions::{process_class_node, process_def_node, process_module_no use super::dispatch::{dispatch_needs_child, dispatch_simple, process_needs_child, DispatchResult}; use super::literals::install_literal_node; use super::parentheses::process_parentheses_node; +use super::returns::process_return_node; /// Build graph from AST (public API wrapper) pub struct AstInstaller<'a> { @@ -81,6 +82,10 @@ pub(crate) fn install_node( return process_parentheses_node(genv, lenv, changes, source, &paren_node); } + if let Some(return_node) = node.as_return_node() { + return process_return_node(genv, lenv, changes, source, &return_node); + } + match dispatch_simple(genv, lenv, node) { DispatchResult::Vertex(vtx) => return Some(vtx), DispatchResult::NotHandled => {} diff --git a/rust/src/analyzer/mod.rs b/rust/src/analyzer/mod.rs index fdc2b2e..e195b59 100644 --- a/rust/src/analyzer/mod.rs +++ b/rust/src/analyzer/mod.rs @@ -8,6 +8,7 @@ mod install; mod literals; mod parameters; mod parentheses; +mod returns; mod variables; pub use install::AstInstaller; diff --git a/rust/src/analyzer/returns.rs b/rust/src/analyzer/returns.rs new file mode 100644 index 0000000..bfef14f --- /dev/null +++ b/rust/src/analyzer/returns.rs @@ -0,0 +1,191 @@ +//! Return statement handling +//! +//! Processes `return expr` by connecting the expression's vertex +//! to the enclosing method's merge vertex. + +use crate::env::{GlobalEnv, LocalEnv}; +use crate::graph::{ChangeSet, VertexId}; + +use super::install::install_node; + +/// Process ReturnNode: connect return value to method's merge vertex +pub(crate) fn process_return_node( + genv: &mut GlobalEnv, + lenv: &mut LocalEnv, + changes: &mut ChangeSet, + source: &str, + return_node: &ruby_prism::ReturnNode, +) -> Option { + // Process return value (first argument only; multi-value return not yet supported) + let value_vtx = if let Some(arguments) = return_node.arguments() { + arguments + .arguments() + .iter() + .next() + .and_then(|arg| install_node(genv, lenv, changes, source, &arg)) + } else { + // `return` without value → nil + Some(genv.new_source(crate::types::Type::Nil)) + }; + + // Connect return value to method's merge vertex + if let Some(vtx) = value_vtx { + if let Some(merge_vtx) = genv.scope_manager.current_method_return_vertex() { + genv.add_edge(vtx, merge_vtx); + } + } + + None +} + +#[cfg(test)] +mod tests { + use crate::env::{GlobalEnv, LocalEnv}; + use crate::graph::ChangeSet; + use crate::parser::ParseSession; + use crate::types::Type; + + fn setup_and_infer(source: &str) -> GlobalEnv { + let session = ParseSession::new(); + let parse_result = session.parse_source(source, "test.rb").unwrap(); + let root = parse_result.node(); + let program = root.as_program_node().unwrap(); + + let mut genv = GlobalEnv::new(); + let mut lenv = LocalEnv::new(); + let mut changes = ChangeSet::new(); + + for stmt in &program.statements().body() { + crate::analyzer::install::install_node( + &mut genv, &mut lenv, &mut changes, source, &stmt, + ); + } + + genv.apply_changes(changes); + genv.run_all(); + genv + } + + fn get_return_type(genv: &GlobalEnv, class_name: &str, method_name: &str) -> String { + let info = genv + .resolve_method(&Type::instance(class_name), method_name) + .unwrap_or_else(|| panic!("{}#{} should be registered", class_name, method_name)); + let vtx = info + .return_vertex + .expect("return_vertex should be Some"); + + if let Some(source) = genv.get_source(vtx) { + source.ty.show() + } else if let Some(vertex) = genv.get_vertex(vtx) { + vertex.show() + } else { + panic!("return_vertex not found"); + } + } + + #[test] + fn test_simple_return() { + let source = r#" +class Foo + def bar + return "hello" + end +end +"#; + let genv = setup_and_infer(source); + assert_eq!(get_return_type(&genv, "Foo", "bar"), "String"); + } + + #[test] + fn test_return_with_implicit_return_union() { + let source = r#" +class Foo + def bar + return "hello" if true + 42 + end +end +"#; + let genv = setup_and_infer(source); + let ty = get_return_type(&genv, "Foo", "bar"); + assert!(ty.contains("Integer"), "should contain Integer, got: {}", ty); + assert!(ty.contains("String"), "should contain String, got: {}", ty); + } + + #[test] + fn test_multiple_returns() { + let source = r#" +class Foo + def bar + return "a" if true + return :b if false + 42 + end +end +"#; + let genv = setup_and_infer(source); + let ty = get_return_type(&genv, "Foo", "bar"); + assert!(ty.contains("Integer"), "should contain Integer, got: {}", ty); + assert!(ty.contains("String"), "should contain String, got: {}", ty); + assert!(ty.contains("Symbol"), "should contain Symbol, got: {}", ty); + } + + #[test] + fn test_return_without_value() { + let source = r#" +class Foo + def bar + return if true + 42 + end +end +"#; + let genv = setup_and_infer(source); + let ty = get_return_type(&genv, "Foo", "bar"); + assert!(ty.contains("Integer"), "should contain Integer, got: {}", ty); + assert!(ty.contains("nil"), "should contain nil, got: {}", ty); + } + + #[test] + fn test_no_return_backward_compat() { + let source = r#" +class Foo + def bar + "hello" + end +end +"#; + let genv = setup_and_infer(source); + assert_eq!(get_return_type(&genv, "Foo", "bar"), "String"); + } + + #[test] + fn test_return_only_method() { + let source = r#" +class Foo + def bar + return "hello" + end +end +"#; + let genv = setup_and_infer(source); + assert_eq!(get_return_type(&genv, "Foo", "bar"), "String"); + } + + #[test] + fn test_return_dead_code_over_approximation() { + let source = r#" +class Foo + def bar + return "hello" + 42 + end +end +"#; + let genv = setup_and_infer(source); + let ty = get_return_type(&genv, "Foo", "bar"); + // Dead code after return is still processed (over-approximation) + assert!(ty.contains("Integer"), "should contain Integer (dead code), got: {}", ty); + assert!(ty.contains("String"), "should contain String, got: {}", ty); + } +} diff --git a/rust/src/env/global_env.rs b/rust/src/env/global_env.rs index 10c4ec5..f5e7b36 100644 --- a/rust/src/env/global_env.rs +++ b/rust/src/env/global_env.rs @@ -210,9 +210,11 @@ impl GlobalEnv { pub fn enter_method(&mut self, name: String) -> ScopeId { // Look for class or module context let receiver_type = self.scope_manager.current_qualified_name(); + let return_vertex = Some(self.new_vertex()); let scope_id = self.scope_manager.new_scope(ScopeKind::Method { name, receiver_type, + return_vertex, }); self.scope_manager.enter_scope(scope_id); scope_id diff --git a/rust/src/env/scope.rs b/rust/src/env/scope.rs index 72397fb..6a1f371 100644 --- a/rust/src/env/scope.rs +++ b/rust/src/env/scope.rs @@ -20,6 +20,7 @@ pub enum ScopeKind { Method { name: String, receiver_type: Option, // Receiver class/module name + return_vertex: Option, // Merge vertex for return statements }, Block, } @@ -309,6 +310,22 @@ impl ScopeManager { Some(result) } + /// Get return_vertex from the nearest enclosing method scope + pub fn current_method_return_vertex(&self) -> Option { + let mut current = Some(self.current_scope); + while let Some(scope_id) = current { + if let Some(scope) = self.scopes.get(&scope_id) { + if let ScopeKind::Method { return_vertex, .. } = &scope.kind { + return *return_vertex; + } + current = scope.parent; + } else { + break; + } + } + None + } + /// Lookup instance variable in enclosing module scope pub fn lookup_instance_var_in_module(&self, name: &str) -> Option { let mut current = Some(self.current_scope); @@ -448,6 +465,7 @@ mod tests { let method_id = sm.new_scope(ScopeKind::Method { name: "test".to_string(), receiver_type: None, + return_vertex: None, }); sm.enter_scope(method_id); @@ -472,6 +490,7 @@ mod tests { let method_id = sm.new_scope(ScopeKind::Method { name: "helper".to_string(), receiver_type: Some("Utils".to_string()), + return_vertex: None, }); sm.enter_scope(method_id); @@ -500,6 +519,7 @@ mod tests { let method_id = sm.new_scope(ScopeKind::Method { name: "get_setting".to_string(), receiver_type: Some("Config".to_string()), + return_vertex: None, }); sm.enter_scope(method_id); @@ -559,6 +579,7 @@ mod tests { let method_id = sm.new_scope(ScopeKind::Method { name: "greet".to_string(), receiver_type: None, + return_vertex: None, }); sm.enter_scope(method_id); diff --git a/test/return_test.rb b/test/return_test.rb new file mode 100644 index 0000000..c4918ee --- /dev/null +++ b/test/return_test.rb @@ -0,0 +1,132 @@ +# frozen_string_literal: true + +require 'test_helper' + +class ReturnTest < Minitest::Test + include CLITestHelper + + # ============================================ + # No Error (check CLI) + # ============================================ + + def test_return_string_upcase_no_error + source = <<~RUBY + class Formatter + def format + return "" if true + "default" + end + + def run + self.format.upcase + end + end + RUBY + + assert_no_check_errors(source) + end + + def test_return_only_method_no_error + source = <<~RUBY + class Foo + def bar + return "hello" + end + + def baz + self.bar.upcase + end + end + RUBY + + assert_no_check_errors(source) + end + + def test_no_return_backward_compat + source = <<~RUBY + class Foo + def bar + "hello" + end + + def baz + self.bar.upcase + end + end + RUBY + + assert_no_check_errors(source) + end + + def test_multiple_returns_union_no_error + source = <<~RUBY + class Converter + def convert + return "hello" if true + return "world" if false + "default" + end + + def run + self.convert.upcase + end + end + RUBY + + assert_no_check_errors(source) + end + + def test_return_with_implicit_return_union_no_error + source = <<~RUBY + class Calculator + def compute + return 0 if true + 42 + end + + def run + self.compute.even? + end + end + RUBY + + assert_no_check_errors(source) + end + + # ============================================ + # Error Detection (check CLI) + # ============================================ + + def test_return_string_even_error + source = <<~RUBY + class Parser + def parse + return "error" + end + + def run + self.parse.even? + end + end + RUBY + + assert_check_error(source, method_name: 'even?', receiver_type: 'String') + end + + def test_return_union_string_integer_even_error + source = <<~RUBY + class Validator + def validate + return "invalid" if true + 42 + end + + def run + self.validate.even? + end + end + RUBY + + assert_check_error(source, method_name: 'even?', receiver_type: 'String') + end +end