diff --git a/core/engine/src/handler/function/module/mf.rs b/core/engine/src/handler/function/module/mf.rs new file mode 100644 index 00000000..f029d27f --- /dev/null +++ b/core/engine/src/handler/function/module/mf.rs @@ -0,0 +1,410 @@ +//! # 自定义函数监听器模块 +//! +//! 本模块实现了CustomListener,用于在JavaScript运行时环境中注册和管理自定义函数。 +//! +//! ## 主要功能 +//! +//! - **函数注册**: 在运行时启动时自动将Rust自定义函数注册到JavaScript的md命名空间 +//! - **命名空间管理**: 创建和管理md作用域,避免全局命名冲突 +//! - **类型转换**: 处理Rust和JavaScript之间的数据类型转换 +//! - **异步支持**: 提供异步函数调用支持,确保不阻塞JavaScript执行 +//! - **错误处理**: 完善的错误捕获和处理机制 +//! +//! ## 使用场景 +//! +//! 该监听器主要用于规则引擎中,允许在规则表达式中通过`md.functionName()`的形式 +//! 调用预定义的Rust函数,从而扩展JavaScript运行时的功能。 +//! +//! ## 架构说明 +//! +//! ```text +//! MfFunctionRegistry → CustomListener → JavaScript Runtime (md namespace) +//! ↓ ↓ ↓ +//! 函数定义存储 函数注册处理 md.functionName() 调用执行 +//! ``` + +use std::future::Future; +use std::pin::Pin; +use crate::handler::function::error::{FunctionResult, ResultExt}; +use crate::handler::function::listener::{RuntimeEvent, RuntimeListener}; +use crate::handler::function::module::export_default; +use crate::handler::function::serde::JsValue; +use rquickjs::module::{Declarations, Exports, ModuleDef}; +use rquickjs::prelude::{Async, Func}; +use rquickjs::{CatchResultExt, Ctx}; +use zen_expression::functions::arguments::Arguments; +use zen_expression::functions::mf_function::MfFunctionRegistry; + +/// 自定义函数监听器 +/// +/// 该监听器负责在JavaScript运行时启动时,将所有注册的自定义函数 +/// 绑定到JavaScript的md命名空间中,使得这些函数可以在规则表达式中通过 +/// `md.functionName()`的形式被调用 +/// +/// # 工作流程 +/// 1. 监听运行时启动事件 +/// 2. 创建或获取md命名空间对象 +/// 3. 从CustomFunctionRegistry获取所有已注册的函数 +/// 4. 将每个函数包装为异步JavaScript函数 +/// 5. 注册到JavaScript的md命名空间中 +pub struct ModuforgeListener { + // 目前为空结构体,后续可以添加配置或状态字段 +} + +impl RuntimeListener for ModuforgeListener { + /// 处理运行时事件的核心方法 + /// + /// # 参数 + /// - `ctx`: QuickJS上下文,用于操作JavaScript环境 + /// - `event`: 运行时事件类型 + /// + /// # 返回值 + /// 返回一个异步Future,包含操作结果 + fn on_event<'js>( + &self, + ctx: Ctx<'js>, + event: RuntimeEvent, + ) -> Pin + 'js>> { + Box::pin(async move { + // 只在运行时启动事件时执行函数注册 + if event != RuntimeEvent::Startup { + return Ok(()); + }; + + // 设置全局函数及变量 + // 创建或获取 md 命名空间对象 + let md_namespace = if ctx.globals().contains_key("md")? { + // 如果 md 已存在,获取它 + ctx.globals().get("md")? + } else { + // 如果 md 不存在,创建一个新的空对象 + let md_obj = rquickjs::Object::new(ctx.clone())?; + ctx.globals().set("md", md_obj.clone())?; + md_obj + }; + + // 从自定义函数注册表中获取所有函数名称 + let functions_keys = MfFunctionRegistry::list_functions(); + + // 遍历每个注册的函数 + for function_key in functions_keys { + // 根据函数名获取函数定义 + let function_definition = + MfFunctionRegistry::get_definition(&function_key); + + if let Some(function_definition) = function_definition { + // 将Rust函数包装为JavaScript异步函数并注册到md命名空间下 + + let function_definition = function_definition.clone(); + let parameters = function_definition.required_parameters(); + match parameters { + 0 => { + md_namespace + .set( + function_key, // 函数名作为md对象的属性名 + Func::from(Async(move |ctx: Ctx<'js>| { + // 克隆函数定义以避免生命周期问题 + let function_definition = + function_definition.clone(); + + async move { + // 调用Rust函数,传入JavaScript参数 + let response = function_definition + .call(Arguments(&[])) + .or_throw(&ctx)?; + + // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值 + let k = + serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + + return rquickjs::Result::Ok( + JsValue(k), + ); + } + })), + ) + .catch(&ctx)?; // 捕获并处理可能的JavaScript异常 + }, + 1 => { + md_namespace + .set( + function_key, // 函数名作为md对象的属性名 + Func::from(Async( + move |ctx: Ctx<'js>, context: JsValue| { + // 克隆函数定义以避免生命周期问题 + let function_definition = + function_definition.clone(); + async move { + // 调用Rust函数,传入JavaScript参数 + let response = function_definition + .call(Arguments(&[context.0])) + .or_throw(&ctx)?; + // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值 + let k = serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + return rquickjs::Result::Ok(JsValue( + k, + )); + } + }, + )), + ) + .catch(&ctx)?; // 捕获并处理可能的JavaScript异常 + }, + 2 => { + md_namespace + .set( + function_key, // 函数名作为md对象的属性名 + Func::from(Async( + move |ctx: Ctx<'js>, context: JsValue,context2: JsValue| { + // 克隆函数定义以避免生命周期问题 + let function_definition = + function_definition.clone(); + async move { + // 调用Rust函数,传入JavaScript参数 + let response = function_definition + .call(Arguments(&[context.0,context2.0])) + .or_throw(&ctx)?; + // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值 + let k = serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + return rquickjs::Result::Ok(JsValue( + k, + )); + } + }, + )), + ) + .catch(&ctx)?; // 捕获并处理可能的JavaScript异常 + }, + 3 => { + md_namespace + .set( + function_key, // 函数名作为md对象的属性名 + Func::from(Async( + move |ctx: Ctx<'js>, context: JsValue,context2: JsValue,context3: JsValue| { + // 克隆函数定义以避免生命周期问题 + let function_definition = + function_definition.clone(); + async move { + // 调用Rust函数,传入JavaScript参数 + let response = function_definition + .call(Arguments(&[context.0,context2.0,context3.0])) + .or_throw(&ctx)?; + // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值 + let k: zen_expression::Variable = serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + return rquickjs::Result::Ok(JsValue( + k, + )); + } + }, + )), + ) + .catch(&ctx)?; // 捕获并处理可能的JavaScript异常 + }, + _ => { + md_namespace + .set( + function_key, // 函数名作为md对象的属性名 + Func::from(Async( + move |ctx: Ctx<'js>, context: Vec| { + // 克隆函数定义以避免生命周期问题 + let function_definition = + function_definition.clone(); + async move { + // 调用Rust函数,传入JavaScript参数 + let response = function_definition + .call(Arguments(&context.iter().map(|arg| arg.0.clone()).collect::>())) + .or_throw(&ctx)?; + // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值 + let k = serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + return rquickjs::Result::Ok(JsValue( + k, + )); + } + }, + )), + ) + .catch(&ctx)?; // 捕获并处理可能的JavaScript异常 + }, + } + } + } + + Ok(()) // 成功完成函数注册 + }) + } +} + +pub struct ModuforgeModule; + +impl ModuleDef for ModuforgeModule { + fn declare<'js>(decl: &Declarations<'js>) -> rquickjs::Result<()> { + // 声明所有可用的函数 + for function_key in MfFunctionRegistry::list_functions() { + decl.declare(function_key.as_str())?; + } + decl.declare("default")?; + Ok(()) + } + + fn evaluate<'js>( + ctx: &Ctx<'js>, + exports: &Exports<'js>, + ) -> rquickjs::Result<()> { + export_default(ctx, exports, |default| { + // 为每个函数创建对应的异步函数 + for function_key in MfFunctionRegistry::list_functions() { + if let Some(function_definition) = + MfFunctionRegistry::get_definition(&function_key) + { + let function_definition = function_definition.clone(); + let parameters = function_definition.required_parameters(); + match parameters { + 0 => { + default.set( + &function_key, + Func::from(Async(move |ctx: Ctx<'js>| { + let function_definition = + function_definition.clone(); + async move { + let response = function_definition + .call(Arguments(&[])) + .or_throw(&ctx)?; + + let result = + serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + + Ok::(JsValue( + result, + )) + } + })), + )?; + }, + 1 => { + //只有一个参数 + default.set( + &function_key, + Func::from(Async( + move |ctx: Ctx<'js>, args: JsValue| { + let function_definition = + function_definition.clone(); + async move { + let response = function_definition + .call(Arguments(&[args.0])) + .or_throw(&ctx)?; + + let result = + serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + + Ok::( + JsValue(result), + ) + } + }, + )), + )?; + }, + 2 => { + //有两个参数 + default.set( + &function_key, + Func::from(Async( + move |ctx: Ctx<'js>, args: JsValue,args2: JsValue| { + let function_definition = + function_definition.clone(); + async move { + let response = function_definition + .call(Arguments(&[args.0,args2.0])) + .or_throw(&ctx)?; + + let result = + serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + + Ok::( + JsValue(result), + ) + } + }, + )), + )?; + }, + 3 => { + //有三个参数 + default.set( + &function_key, + Func::from(Async( + move |ctx: Ctx<'js>, args: JsValue,args2: JsValue,args3: JsValue| { + let function_definition = + function_definition.clone(); + async move { + let response = function_definition + .call(Arguments(&[args.0,args2.0,args3.0])) + .or_throw(&ctx)?; + + let result = + serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + + Ok::( + JsValue(result), + ) + } + }, + )), + )?; + }, + _ => { + //4个以上参数 的参数必须以数组的形式传入 + default.set( + &function_key, + Func::from(Async( + move |ctx: Ctx<'js>, args: Vec| { + let function_definition = + function_definition.clone(); + async move { + let args_vec = args + .iter() + .map(|arg| arg.0.clone()) + .collect::>(); + let response = function_definition + .call(Arguments(&args_vec)) + .or_throw(&ctx)?; + + let result = + serde_json::to_value(response) + .or_throw(&ctx)? + .into(); + + Ok::( + JsValue(result), + ) + } + }, + )), + )?; + }, + } + } + } + + Ok(()) + }) + } +} diff --git a/core/engine/src/handler/function/module/mod.rs b/core/engine/src/handler/function/module/mod.rs index 180afb31..5a84acce 100644 --- a/core/engine/src/handler/function/module/mod.rs +++ b/core/engine/src/handler/function/module/mod.rs @@ -8,11 +8,13 @@ use rquickjs::module::{Declared, Exports}; use rquickjs::{embed, Ctx, Error, Module, Object}; use crate::handler::function::module::http::HttpModule; +use crate::handler::function::module::mf::ModuforgeModule; use crate::handler::function::module::zen::ZenModule; pub(crate) mod console; pub(crate) mod http; pub(crate) mod zen; +pub mod mf; static JS_BUNDLE: Bundle = embed! { "dayjs": "js/dayjs.mjs", @@ -61,7 +63,7 @@ struct BaseModuleLoader { impl BaseModuleLoader { pub fn new() -> Self { - let mut hs = HashSet::from(["zen".to_string(), "http".to_string()]); + let mut hs = HashSet::from(["zen".to_string(), "http".to_string(),"mf".to_string()]); JS_BUNDLE.iter().for_each(|(key, _)| { hs.insert(key.to_string()); @@ -72,7 +74,8 @@ impl BaseModuleLoader { defined_modules: RefCell::new(hs), md_loader: MDLoader::default() .with_module("zen", ZenModule) - .with_module("http", HttpModule), + .with_module("http", HttpModule) + .with_module("mf", ModuforgeModule), } } diff --git a/core/engine/src/handler/graph.rs b/core/engine/src/handler/graph.rs index f74389dc..587dedad 100644 --- a/core/engine/src/handler/graph.rs +++ b/core/engine/src/handler/graph.rs @@ -3,6 +3,7 @@ use crate::handler::decision::DecisionHandler; use crate::handler::expression::ExpressionHandler; use crate::handler::function::function::{Function, FunctionConfig}; use crate::handler::function::module::console::ConsoleListener; +use crate::handler::function::module::mf::ModuforgeListener; use crate::handler::function::module::zen::ZenListener; use crate::handler::function::FunctionHandler; use crate::handler::function_v1; @@ -110,6 +111,7 @@ impl DecisionGraph< loader: self.loader.clone(), adapter: self.adapter.clone(), }), + Box::new(ModuforgeListener{}) ]), }) .await diff --git a/core/expression/examples/compilation_demo.rs b/core/expression/examples/compilation_demo.rs new file mode 100644 index 00000000..ef078a2b --- /dev/null +++ b/core/expression/examples/compilation_demo.rs @@ -0,0 +1,197 @@ +use zen_expression::{Isolate, evaluate_expression, Variable}; +use serde_json::json; +use rust_decimal_macros::dec; + +fn main() { + println!("=== rules_expression 编译与VM执行过程演示 ===\n"); + + // 演示1: 基础表达式编译和执行 + demo_basic_compilation(); + + // 演示2: 复杂表达式的字节码分析 + demo_complex_expression(); + + // 演示3: 高性能重复执行 + demo_performance_execution(); + + // 演示4: 不同数据类型的处理 + demo_data_types(); + + // 演示5: 区间和条件表达式 + demo_intervals_and_conditions(); +} + +fn demo_basic_compilation() { + println!("【演示1: 基础表达式编译和执行】"); + + // 创建上下文环境 + let context = json!({ + "tax": { + "percentage": 10 + }, + "amount": 50 + }); + + let expression = "amount * tax.percentage / 100"; + println!("表达式: {}", expression); + println!("上下文: {}", context); + + // 方式1: 直接评估(内部完成完整的编译->执行流程) + let result = + evaluate_expression(expression, context.clone().into()).unwrap(); + println!("计算结果: {:?}", result); + + // 方式2: 使用Isolate查看详细过程 + let mut isolate = Isolate::with_environment(context.into()); + + // 编译表达式获取字节码 + let compiled = isolate.compile_standard(expression).unwrap(); + println!("编译后的字节码: {:?}", compiled.bytecode()); + + // 执行编译后的表达式 + let new_context = json!({"tax": {"percentage": 15}, "amount": 100}); + let result2 = compiled.evaluate(new_context.into()).unwrap(); + println!("新上下文执行结果: {:?}\n", result2); +} + +fn demo_complex_expression() { + println!("【演示2: 复杂表达式的字节码分析】"); + + let mut isolate = Isolate::new(); + let expression = "(a + b) * c - d / 2"; + + println!("复杂表达式: {}", expression); + + // 编译并查看字节码 + let compiled = isolate.compile_standard(expression).unwrap(); + println!("生成的字节码指令:"); + for (i, opcode) in compiled.bytecode().iter().enumerate() { + println!(" {}: {:?}", i, opcode); + } + + // 执行演示 + let context = json!({"a": 10, "b": 20, "c": 3, "d": 8}); + println!("执行上下文: {}", context); + + let result = compiled.evaluate(context.into()).unwrap(); + println!("计算结果: {:?}", result); + println!("验证: (10 + 20) * 3 - 8 / 2 = 30 * 3 - 4 = 90 - 4 = 86\n"); +} + +fn demo_performance_execution() { + println!("【演示3: 高性能重复执行】"); + + let context = json!({ + "items": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "factor": 0.1 + }); + + let mut isolate = Isolate::with_environment(context.into()); + + // 预编译表达式 + let expression = "sum(items) * factor"; + let compiled = isolate.compile_standard(expression).unwrap(); + + println!("表达式: {}", expression); + println!("预编译完成,开始高性能重复执行..."); + + // 模拟高频执行 + let iterations = 100_000; + let start = std::time::Instant::now(); + + for _ in 0..iterations { + // 重复使用预编译的字节码,VM重用内存 + let _result = isolate.run_standard(expression).unwrap(); + } + + let duration = start.elapsed(); + println!("执行 {} 次耗时: {:?}", iterations, duration); + println!("平均每次执行: {:?}", duration / iterations); + println!( + "每秒执行次数: {:.0}\n", + iterations as f64 / duration.as_secs_f64() + ); +} + +fn demo_data_types() { + println!("【演示4: 不同数据类型处理】"); + + let context = json!({ + "user": { + "name": "Alice", + "age": 25, + "active": true, + "scores": [85, 92, 78, 96] + }, + "settings": { + "threshold": 80 + } + }); + + let mut isolate = Isolate::with_environment(context.into()); + + let test_cases = vec![ + ( + "user.name + \" is \" + string(user.age) + \" years old\"", + "字符串拼接", + ), + ("user.age >= 18", "布尔运算"), + ("len(user.scores)", "数组长度"), + ("max(user.scores)", "数组最大值"), + ("avg(user.scores) > settings.threshold", "数组平均值比较"), + ]; + + for (expr, desc) in test_cases { + let result = isolate.run_standard(expr).unwrap(); + println!("{}: {} = {:?}", desc, expr, result); + } + println!(); +} + +fn demo_intervals_and_conditions() { + println!("【演示5: 区间和条件表达式】"); + + let context = json!({ + "student": { + "age": 20, + "score": 85, + "grade": "B+" + }, + "rules": { + "adult_age": 18, + "passing_score": 60, + "excellent_score": 90 + } + }); + + let mut isolate = Isolate::with_environment(context.into()); + + let test_expressions = vec![ + ("student.age >= rules.adult_age", "成年判断"), + ("student.score in [rules.passing_score..100]", "及格区间判断"), + ("student.score in [rules.excellent_score..100]", "优秀区间判断"), + ( + "student.score in (rules.passing_score..rules.excellent_score)", + "良好区间判断(开区间)", + ), + ( + "student.age in [18..65) and student.score >= rules.passing_score", + "复合条件", + ), + ]; + + for (expr, desc) in test_expressions { + let result = isolate.run_standard(expr).unwrap(); + println!("{}: {} = {:?}", desc, expr, result); + } + + // 演示区间转换为数组 + println!("\n区间数组转换演示:"); + let range_expr = "[1..5]"; + let compiled = isolate.compile_standard(range_expr).unwrap(); + println!("区间表达式: {}", range_expr); + println!("字节码: {:?}", compiled.bytecode()); + + let result = compiled.evaluate(Variable::empty_object()).unwrap(); + println!("区间结果: {:?}", result); +} diff --git a/core/expression/examples/custom_function_demo.rs b/core/expression/examples/custom_function_demo.rs new file mode 100644 index 00000000..ffdcb0d2 --- /dev/null +++ b/core/expression/examples/custom_function_demo.rs @@ -0,0 +1,86 @@ +use zen_expression::{Isolate, Variable}; +use zen_expression::functions::mf_function::{ + MfFunctionHelper, MfFunctionRegistry, +}; +use zen_expression::variable::VariableType; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +// 1. 定义一个简单的、我们自己的状态 +#[derive(Debug)] +struct MyState { + call_count: Mutex, +} + +impl MyState { + fn new() -> Self { + Self { call_count: Mutex::new(0) } + } + + fn increment(&self) -> u32 { + let mut count = self.call_count.lock().unwrap(); + *count += 1; + *count + } +} + +fn main() -> anyhow::Result<()> { + println!("=== 自定义函数与泛型State集成演示 ===\n"); + + // === 第一部分: 演示使用我们自定义的 `MyState` === + println!("--- 场景1: 使用自定义的 MyState ---"); + let my_state = Arc::new(MyState::new()); + + // 2. 为 `MyState` 创建一个 Helper + let my_helper = MfFunctionHelper::::new(); + + // 3. 注册一个可以访问 `MyState` 的函数 + println!("注册函数: getMyStateCallCount()"); + my_helper + .register_function( + "getMyStateCallCount".to_string(), + vec![], + VariableType::Number, + Box::new(|_args, state_opt: Option<&MyState>| { + if let Some(state) = state_opt { + // `state` 的类型是 &MyState + let count = state.increment(); + Ok(Variable::Number(count.into())) + } else { + Ok(Variable::Number((-1i32).into())) + } + }), + ) + .map_err(|e| anyhow::anyhow!(e))?; + + // 4. 创建 Isolate 并使用 `MyState` 执行表达式 + let mut isolate = Isolate::new(); + println!("使用 `MyState` 执行 'getMyStateCallCount()'"); + let result1 = isolate + .run_standard_with_state("getMyStateCallCount()", my_state.clone())?; + println!(" 第一次调用结果: {}", result1); + let result2 = isolate + .run_standard_with_state("getMyStateCallCount()", my_state.clone())?; + println!(" 第二次调用结果: {}", result2); + + // === 第三部分: 验证两种函数可以共存 === + println!("\n--- 场景3: 验证两种状态的函数可以共存 ---"); + println!("再次调用 `getMyStateCallCount` (应为3)"); + let result4 = isolate + .run_standard_with_state("getMyStateCallCount()", my_state.clone())?; + println!(" 结果: {}", result4); + + // 显示所有已注册的自定义函数 + println!("\n=== 已注册的自定义函数 ==="); + let functions = MfFunctionRegistry::list_functions(); + for func in functions { + println!("- {}", func); + } + + // 清理 + println!("\n清理所有自定义函数..."); + MfFunctionRegistry::clear(); + + println!("演示完成!"); + Ok(()) +} diff --git a/core/expression/src/compiler/compiler.rs b/core/expression/src/compiler/compiler.rs index b03fae0a..bd997309 100644 --- a/core/expression/src/compiler/compiler.rs +++ b/core/expression/src/compiler/compiler.rs @@ -396,7 +396,7 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { }), }, Node::FunctionCall { kind, arguments } => match kind { - FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => { + FunctionKind::Internal(_) | FunctionKind::Deprecated(_) | FunctionKind::Mf(_) => { let function = FunctionRegistry::get_definition(kind).ok_or_else(|| { CompilerError::UnknownFunction { name: kind.to_string(), diff --git a/core/expression/src/functions/mf_function.rs b/core/expression/src/functions/mf_function.rs new file mode 100644 index 00000000..a01b7492 --- /dev/null +++ b/core/expression/src/functions/mf_function.rs @@ -0,0 +1,319 @@ +//! 自定义函数模块 +//! +//! 支持在运行时动态注册自定义函数,并可以访问State + +use crate::functions::defs::{ + FunctionDefinition, FunctionSignature, StaticFunction, +}; +use crate::functions::arguments::Arguments; +use crate::variable::{Variable, VariableType}; +use std::rc::Rc; +use std::sync::Arc; +use std::collections::HashMap; +use std::cell::RefCell; +use std::fmt::Display; +use anyhow::Result as AnyhowResult; +use std::any::Any; +use std::marker::PhantomData; + +/// 自定义函数标识符 +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct MfFunction { + /// 函数名称 + pub name: String, +} + +impl MfFunction { + pub fn new(name: String) -> Self { + Self { name } + } +} + +impl Display for MfFunction { + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +impl TryFrom<&str> for MfFunction { + type Error = strum::ParseError; + + fn try_from(value: &str) -> Result { + // 检查是否为已注册的自定义函数 + if MfFunctionRegistry::is_registered(value) { + Ok(MfFunction::new(value.to_string())) + } else { + Err(strum::ParseError::VariantNotFound) + } + } +} + +/// 自定义函数的内部执行器类型 (类型擦除) +type ErasedExecutor = Box< + dyn Fn( + &Arguments, + Option<&Arc>, + ) -> AnyhowResult + + 'static, +>; + +/// 自定义函数定义 +pub struct MfFunctionDefinition { + /// 函数名称 + pub name: String, + /// 函数签名 + pub signature: FunctionSignature, + /// 执行器 + pub executor: ErasedExecutor, +} + +impl MfFunctionDefinition { + pub fn new( + name: String, + signature: FunctionSignature, + executor: ErasedExecutor, + ) -> Self { + Self { name, signature, executor } + } +} + +impl FunctionDefinition for MfFunctionDefinition { + fn call( + &self, + args: Arguments, + ) -> AnyhowResult { + // 尝试获取State上下文(如果可用) + let state = CURRENT_STATE.with(|s| s.borrow().clone()); + (self.executor)(&args, state.as_ref()) + } + + fn required_parameters(&self) -> usize { + self.signature.parameters.len() + } + + fn optional_parameters(&self) -> usize { + 0 // 暂时不支持可选参数 + } + + fn check_types( + &self, + args: &[VariableType], + ) -> crate::functions::defs::FunctionTypecheck { + let mut typecheck = + crate::functions::defs::FunctionTypecheck::default(); + typecheck.return_type = self.signature.return_type.clone(); + + if args.len() != self.required_parameters() { + typecheck.general = Some(format!( + "期望 `{}` 参数, 实际 `{}` 参数.", + self.required_parameters(), + args.len() + )); + } + + // 检查每个参数类型 + for (i, (arg, expected_type)) in + args.iter().zip(self.signature.parameters.iter()).enumerate() + { + if !arg.satisfies(expected_type) { + typecheck.arguments.push(( + i, + format!( + "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.", + ), + )); + } + } + + typecheck + } + + fn param_type( + &self, + index: usize, + ) -> Option { + self.signature.parameters.get(index).cloned() + } + + fn param_type_str( + &self, + index: usize, + ) -> String { + self.signature + .parameters + .get(index) + .map(|x| x.to_string()) + .unwrap_or_else(|| "never".to_string()) + } + + fn return_type(&self) -> VariableType { + self.signature.return_type.clone() + } + + fn return_type_str(&self) -> String { + self.signature.return_type.to_string() + } +} + +thread_local! { + /// 当前State上下文(用于自定义函数访问) + static CURRENT_STATE: RefCell>> = RefCell::new(None); +} + +/// 自定义函数注册表 +pub struct MfFunctionRegistry { + functions: HashMap>, +} + +impl MfFunctionRegistry { + thread_local!( + static INSTANCE: RefCell = RefCell::new(MfFunctionRegistry::new()) + ); + + fn new() -> Self { + Self { functions: HashMap::new() } + } + + /// 注册自定义函数 (内部使用) + fn register_function_erased( + name: String, + signature: FunctionSignature, + executor: ErasedExecutor, + ) -> Result<(), String> { + Self::INSTANCE.with(|registry| { + let mut reg = registry.borrow_mut(); + if reg.functions.contains_key(&name) { + return Err(format!("函数 '{}' 已经存在", name)); + } + + let definition = MfFunctionDefinition::new( + name.clone(), + signature, + executor, + ); + reg.functions.insert(name, Rc::new(definition)); + Ok(()) + }) + } + + /// 获取函数定义 + pub fn get_definition(name: &str) -> Option> { + Self::INSTANCE.with(|registry| { + registry + .borrow() + .functions + .get(name) + .map(|def| def.clone() as Rc) + }) + } + + /// 检查函数是否已注册 + pub fn is_registered(name: &str) -> bool { + Self::INSTANCE + .with(|registry| registry.borrow().functions.contains_key(name)) + } + + /// 设置当前State上下文 + pub fn set_current_state(state: Option>) { + CURRENT_STATE.with(|s| { + *s.borrow_mut() = state.map(|st| st as Arc); + }); + } + + /// 检查当前是否有活跃的State + pub fn has_current_state() -> bool { + CURRENT_STATE.with(|s| s.borrow().is_some()) + } + + /// 清理当前State上下文 + pub fn clear_current_state() { + CURRENT_STATE.with(|s| { + *s.borrow_mut() = None; + }); + } + + /// 列出所有已注册的函数 + pub fn list_functions() -> Vec { + Self::INSTANCE.with(|registry| { + registry.borrow().functions.keys().cloned().collect() + }) + } + + /// 清空所有注册的函数 + pub fn clear() { + Self::INSTANCE.with(|registry| { + registry.borrow_mut().functions.clear(); + }); + } +} + +/// 用于注册特定状态类型 `S` 的函数的辅助结构。 +pub struct MfFunctionHelper { + _marker: PhantomData, +} + +impl MfFunctionHelper { + /// 创建一个新的辅助实例。 + pub fn new() -> Self { + Self { _marker: PhantomData } + } + + /// 注册一个自定义函数。 + /// + /// # Parameters + /// - `name`: 函数名。 + /// - `params`: 函数参数类型列表。 + /// - `return_type`: 函数返回类型。 + /// - `executor`: 函数的实现,它接收参数和一个可选的 `Arc` 状态引用。 + pub fn register_function( + &self, + name: String, + params: Vec, + return_type: VariableType, + executor: Box< + dyn Fn(&Arguments, Option<&S>) -> AnyhowResult + 'static, + >, + ) -> Result<(), String> { + let signature = FunctionSignature { parameters: params, return_type }; + + let wrapped_executor: ErasedExecutor = + Box::new(move |args, state_any| { + let typed_state = state_any.and_then(|s| s.downcast_ref::()); + executor(args, typed_state) + }); + + MfFunctionRegistry::register_function_erased( + name, + signature, + wrapped_executor, + ) + } +} + +impl Default for MfFunctionHelper { + fn default() -> Self { + Self::new() + } +} + +impl From<&MfFunction> for Rc { + fn from(custom: &MfFunction) -> Self { + MfFunctionRegistry::get_definition(&custom.name).unwrap_or_else( + || { + // 如果函数不存在,返回一个错误函数 + Rc::new(StaticFunction { + signature: FunctionSignature { + parameters: vec![], + return_type: VariableType::Null, + }, + implementation: Rc::new(|_| { + Err(anyhow::anyhow!("自定义函数未找到")) + }), + }) + }, + ) + } +} diff --git a/core/expression/src/functions/mod.rs b/core/expression/src/functions/mod.rs index c583821e..5b33029e 100644 --- a/core/expression/src/functions/mod.rs +++ b/core/expression/src/functions/mod.rs @@ -3,24 +3,29 @@ pub use crate::functions::defs::FunctionTypecheck; pub use crate::functions::deprecated::DeprecatedFunction; pub use crate::functions::internal::InternalFunction; pub use crate::functions::method::{MethodKind, MethodRegistry}; +use crate::functions::mf_function::MfFunction; +pub use crate::functions::state_guard::{StateGuard, with_state_async}; pub use crate::functions::registry::FunctionRegistry; use std::fmt::Display; use strum_macros::{Display, EnumIter, EnumString, IntoStaticStr}; -pub(crate) mod arguments; +pub mod arguments; mod date_method; pub(crate) mod defs; mod deprecated; pub(crate) mod internal; mod method; pub(crate) mod registry; +pub mod mf_function; +pub mod state_guard; #[derive(Debug, PartialEq, Eq, Clone)] pub enum FunctionKind { Internal(InternalFunction), Deprecated(DeprecatedFunction), Closure(ClosureFunction), + Mf(MfFunction), } impl TryFrom<&str> for FunctionKind { @@ -31,6 +36,7 @@ impl TryFrom<&str> for FunctionKind { .map(FunctionKind::Internal) .or_else(|_| DeprecatedFunction::try_from(value).map(FunctionKind::Deprecated)) .or_else(|_| ClosureFunction::try_from(value).map(FunctionKind::Closure)) + .or_else(|_| MfFunction::try_from(value).map(FunctionKind::Mf)) } } @@ -40,6 +46,7 @@ impl Display for FunctionKind { FunctionKind::Internal(i) => write!(f, "{i}"), FunctionKind::Deprecated(d) => write!(f, "{d}"), FunctionKind::Closure(c) => write!(f, "{c}"), + FunctionKind::Mf(m) => write!(f, "{m}"), } } } diff --git a/core/expression/src/functions/registry.rs b/core/expression/src/functions/registry.rs index a17c3645..da75f8de 100644 --- a/core/expression/src/functions/registry.rs +++ b/core/expression/src/functions/registry.rs @@ -1,4 +1,5 @@ use crate::functions::defs::FunctionDefinition; +use crate::functions::mf_function::MfFunctionRegistry; use crate::functions::{DeprecatedFunction, FunctionKind, InternalFunction}; use nohash_hasher::{BuildNoHashHasher, IsEnabled}; use std::cell::RefCell; @@ -33,6 +34,9 @@ impl FunctionRegistry { Self::INSTANCE.with_borrow(|i| i.deprecated_functions.get(&deprecated).cloned()) } FunctionKind::Closure(_) => None, + FunctionKind::Mf(mf) => { + MfFunctionRegistry::get_definition(&mf.name) + }, } } diff --git a/core/expression/src/functions/state_guard.rs b/core/expression/src/functions/state_guard.rs new file mode 100644 index 00000000..da25f736 --- /dev/null +++ b/core/expression/src/functions/state_guard.rs @@ -0,0 +1,184 @@ +//! State 守卫模块 +//! +//! 提供 RAII 模式的 State 管理,确保异常安全 + +use std::sync::Arc; +use super::mf_function::MfFunctionRegistry; +use std::marker::PhantomData; + +/// State 守卫,使用 RAII 模式自动管理 State 的设置和清理 +/// +/// 当 StateGuard 被创建时,会自动设置当前线程的 State 上下文 +/// 当 StateGuard 被丢弃时(包括异常情况),会自动清理 State 上下文 +/// +/// # 示例 +/// ```rust,ignore +/// use std::sync::Arc; +/// use mf_state::State; +/// use mf_rules_expression::functions::StateGuard; +/// +/// // 创建 State +/// let state = Arc::new(State::default()); +/// +/// { +/// // 设置 State 上下文 +/// let _guard = StateGuard::new(state); +/// +/// // 在这个作用域内,自定义函数可以访问 State +/// // 即使发生 panic,State 也会被正确清理 +/// +/// } // 这里 StateGuard 被自动丢弃,State 上下文被清理 +/// ``` +pub struct StateGuard { + _private: PhantomData, +} + +impl StateGuard { + /// 创建新的 State 守卫 + /// + /// # 参数 + /// * `state` - 要设置的 State 对象 + /// + /// # 返回值 + /// 返回 StateGuard 实例,当其被丢弃时会自动清理 State + pub fn new(state: Arc) -> Self { + MfFunctionRegistry::set_current_state(Some(state)); + Self { _private: PhantomData } + } + + /// 创建空的 State 守卫(用于清理已有的 State) + /// + /// # 返回值 + /// 返回 StateGuard 实例,会立即清理当前 State 并在丢弃时保持清理状态 + pub fn empty() -> Self { + MfFunctionRegistry::clear_current_state(); + Self { _private: PhantomData } + } + + /// 获取当前是否有活跃的 State + /// + /// # 返回值 + /// * `true` - 当前线程有活跃的 State + /// * `false` - 当前线程没有 State + pub fn has_active_state() -> bool { + MfFunctionRegistry::has_current_state() + } +} + +impl Drop for StateGuard { + /// 自动清理 State 上下文 + /// + /// 当 StateGuard 被丢弃时(正常情况或异常情况), + /// 会自动清理当前线程的 State 上下文 + fn drop(&mut self) { + MfFunctionRegistry::clear_current_state(); + } +} + +/// 便利宏,用于在指定作用域内设置 State +/// +/// # 示例 +/// ```rust,ignore +/// use mf_rules_expression::with_state; +/// +/// let state = Arc::new(State::default()); +/// +/// with_state!(state => { +/// // 在这个块内,State 是活跃的 +/// }); +/// // State 在这里已经被清理 +/// ``` +#[macro_export] +macro_rules! with_state { + ($state:expr => $block:block) => {{ + let _guard = $crate::functions::StateGuard::new($state); + $block + }}; +} + +/// 异步版本的 State 守卫便利函数 +/// +/// # 参数 +/// * `state` - 要设置的 State 对象 +/// * `future` - 要在 State 上下文中执行的异步操作 +/// +/// # 返回值 +/// 返回异步操作的结果 +/// +/// # 示例 +/// ```rust,ignore +/// use mf_rules_expression::functions::with_state_async; +/// +/// let state = Arc::new(State::default()); +/// +/// let result = with_state_async(state, async { +/// // aync block +/// }).await; +/// ``` +pub async fn with_state_async( + state: Arc, + future: F, +) -> T +where + S: Send + Sync + 'static, + F: FnOnce() -> Fut, + Fut: std::future::Future, +{ + let _guard = StateGuard::new(state); + future().await +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + // A dummy struct for testing purposes + struct DummyState; + + #[test] + fn test_state_guard_basic() { + // 初始状态应该没有 State + assert!(!StateGuard::::has_active_state()); + + { + // 创建一个模拟的 State + let state = Arc::new(DummyState); + let _guard = StateGuard::new(state); + + // 在这个作用域内应该有活跃的 State + assert!(StateGuard::::has_active_state()); + } + + // 离开作用域后,State 应该被清理 + assert!(!StateGuard::::has_active_state()); + } + + #[test] + fn test_state_guard_panic_safety() { + assert!(!StateGuard::::has_active_state()); + + let result = std::panic::catch_unwind(|| { + let state = Arc::new(DummyState); + let _guard = StateGuard::new(state); + + // 模拟 panic + panic!("测试 panic 安全性"); + }); + + // 即使发生了 panic,State 也应该被正确清理 + assert!(!StateGuard::::has_active_state()); + assert!(result.is_err()); + } + + #[test] + fn test_empty_guard() { + let state = Arc::new(DummyState); + let _guard = StateGuard::new(state); + assert!(StateGuard::::has_active_state()); + + // 创建空守卫应该立即清理 State + let _guard_empty = StateGuard::::empty(); + assert!(!StateGuard::::has_active_state()); + } +} diff --git a/core/expression/src/intellisense/types/provider.rs b/core/expression/src/intellisense/types/provider.rs index bcea1f4e..69f27588 100644 --- a/core/expression/src/intellisense/types/provider.rs +++ b/core/expression/src/intellisense/types/provider.rs @@ -403,7 +403,7 @@ impl TypesProvider { } match kind { - FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => { + FunctionKind::Internal(_) | FunctionKind::Deprecated(_) | FunctionKind::Mf(_) => { let Some(def) = FunctionRegistry::get_definition(kind) else { return V(VariableType::Any); }; diff --git a/core/expression/src/isolate.rs b/core/expression/src/isolate.rs index 8412f817..d67a4212 100644 --- a/core/expression/src/isolate.rs +++ b/core/expression/src/isolate.rs @@ -140,6 +140,17 @@ impl<'a> Isolate<'a> { Ok(result) } + pub fn run_standard_with_state( + &mut self, + source: &'a str, + state: Arc, + ) -> Result { + // 使用 StateGuard 自动管理 State 生命周期 + let _guard = crate::functions::StateGuard::new(state); + + // 运行表达式,即使发生异常,State 也会被正确清理 + self.run_standard(source) + } pub fn compile_unary(&mut self, source: &'a str) -> Result, IsolateError> { self.run_internal(source, ExpressionKind::Unary)?; let bytecode = self.compiler.get_bytecode().to_vec(); @@ -157,6 +168,17 @@ impl<'a> Isolate<'a> { result.as_bool().ok_or_else(|| IsolateError::ValueCastError) } + pub fn run_unary_with_state( + &mut self, + source: &'a str, + state: Arc, + ) -> Result { + // 使用 StateGuard 自动管理 State 生命周期 + let _guard = crate::functions::StateGuard::new(state); + + // 运行表达式,即使发生异常,State 也会被正确清理 + self.run_unary(source) + } } /// Errors which happen within isolate or during evaluation diff --git a/core/expression/src/parser/unary.rs b/core/expression/src/parser/unary.rs index c807167b..5ae3d2ea 100644 --- a/core/expression/src/parser/unary.rs +++ b/core/expression/src/parser/unary.rs @@ -344,65 +344,66 @@ impl From<&Node<'_>> for UnaryNodeBehaviour { }, Node::FunctionCall { kind, .. } => match kind { FunctionKind::Internal(i) => match i { - InternalFunction::Len => CompareWithReference(Equal), - InternalFunction::Upper => CompareWithReference(Equal), - InternalFunction::Lower => CompareWithReference(Equal), - InternalFunction::Trim => CompareWithReference(Equal), - InternalFunction::Abs => CompareWithReference(Equal), - InternalFunction::Sum => CompareWithReference(Equal), - InternalFunction::Avg => CompareWithReference(Equal), - InternalFunction::Min => CompareWithReference(Equal), - InternalFunction::Max => CompareWithReference(Equal), - InternalFunction::Rand => CompareWithReference(Equal), - InternalFunction::Median => CompareWithReference(Equal), - InternalFunction::Mode => CompareWithReference(Equal), - InternalFunction::Floor => CompareWithReference(Equal), - InternalFunction::Ceil => CompareWithReference(Equal), - InternalFunction::Round => CompareWithReference(Equal), - InternalFunction::Trunc => CompareWithReference(Equal), - InternalFunction::String => CompareWithReference(Equal), - InternalFunction::Number => CompareWithReference(Equal), - InternalFunction::Bool => CompareWithReference(Equal), - InternalFunction::Flatten => CompareWithReference(In), - InternalFunction::Extract => CompareWithReference(In), - InternalFunction::Contains => AsBoolean, - InternalFunction::StartsWith => AsBoolean, - InternalFunction::EndsWith => AsBoolean, - InternalFunction::Matches => AsBoolean, - InternalFunction::FuzzyMatch => CompareWithReference(Equal), - InternalFunction::Split => CompareWithReference(In), - InternalFunction::IsNumeric => AsBoolean, - InternalFunction::Keys => CompareWithReference(In), - InternalFunction::Values => CompareWithReference(In), - InternalFunction::Type => CompareWithReference(Equal), - InternalFunction::Date => CompareWithReference(Equal), - }, + InternalFunction::Len => CompareWithReference(Equal), + InternalFunction::Upper => CompareWithReference(Equal), + InternalFunction::Lower => CompareWithReference(Equal), + InternalFunction::Trim => CompareWithReference(Equal), + InternalFunction::Abs => CompareWithReference(Equal), + InternalFunction::Sum => CompareWithReference(Equal), + InternalFunction::Avg => CompareWithReference(Equal), + InternalFunction::Min => CompareWithReference(Equal), + InternalFunction::Max => CompareWithReference(Equal), + InternalFunction::Rand => CompareWithReference(Equal), + InternalFunction::Median => CompareWithReference(Equal), + InternalFunction::Mode => CompareWithReference(Equal), + InternalFunction::Floor => CompareWithReference(Equal), + InternalFunction::Ceil => CompareWithReference(Equal), + InternalFunction::Round => CompareWithReference(Equal), + InternalFunction::Trunc => CompareWithReference(Equal), + InternalFunction::String => CompareWithReference(Equal), + InternalFunction::Number => CompareWithReference(Equal), + InternalFunction::Bool => CompareWithReference(Equal), + InternalFunction::Flatten => CompareWithReference(In), + InternalFunction::Extract => CompareWithReference(In), + InternalFunction::Contains => AsBoolean, + InternalFunction::StartsWith => AsBoolean, + InternalFunction::EndsWith => AsBoolean, + InternalFunction::Matches => AsBoolean, + InternalFunction::FuzzyMatch => CompareWithReference(Equal), + InternalFunction::Split => CompareWithReference(In), + InternalFunction::IsNumeric => AsBoolean, + InternalFunction::Keys => CompareWithReference(In), + InternalFunction::Values => CompareWithReference(In), + InternalFunction::Type => CompareWithReference(Equal), + InternalFunction::Date => CompareWithReference(Equal), + }, FunctionKind::Deprecated(d) => match d { - DeprecatedFunction::Date => CompareWithReference(Equal), - DeprecatedFunction::Time => CompareWithReference(Equal), - DeprecatedFunction::Duration => CompareWithReference(Equal), - DeprecatedFunction::Year => CompareWithReference(Equal), - DeprecatedFunction::DayOfWeek => CompareWithReference(Equal), - DeprecatedFunction::DayOfMonth => CompareWithReference(Equal), - DeprecatedFunction::DayOfYear => CompareWithReference(Equal), - DeprecatedFunction::WeekOfYear => CompareWithReference(Equal), - DeprecatedFunction::MonthOfYear => CompareWithReference(Equal), - DeprecatedFunction::MonthString => CompareWithReference(Equal), - DeprecatedFunction::DateString => CompareWithReference(Equal), - DeprecatedFunction::WeekdayString => CompareWithReference(Equal), - DeprecatedFunction::StartOf => CompareWithReference(Equal), - DeprecatedFunction::EndOf => CompareWithReference(Equal), - }, + DeprecatedFunction::Date => CompareWithReference(Equal), + DeprecatedFunction::Time => CompareWithReference(Equal), + DeprecatedFunction::Duration => CompareWithReference(Equal), + DeprecatedFunction::Year => CompareWithReference(Equal), + DeprecatedFunction::DayOfWeek => CompareWithReference(Equal), + DeprecatedFunction::DayOfMonth => CompareWithReference(Equal), + DeprecatedFunction::DayOfYear => CompareWithReference(Equal), + DeprecatedFunction::WeekOfYear => CompareWithReference(Equal), + DeprecatedFunction::MonthOfYear => CompareWithReference(Equal), + DeprecatedFunction::MonthString => CompareWithReference(Equal), + DeprecatedFunction::DateString => CompareWithReference(Equal), + DeprecatedFunction::WeekdayString => CompareWithReference(Equal), + DeprecatedFunction::StartOf => CompareWithReference(Equal), + DeprecatedFunction::EndOf => CompareWithReference(Equal), + }, FunctionKind::Closure(c) => match c { - ClosureFunction::All => AsBoolean, - ClosureFunction::Some => AsBoolean, - ClosureFunction::None => AsBoolean, - ClosureFunction::One => AsBoolean, - ClosureFunction::Filter => CompareWithReference(In), - ClosureFunction::Map => CompareWithReference(In), - ClosureFunction::FlatMap => CompareWithReference(In), - ClosureFunction::Count => CompareWithReference(Equal), - }, + ClosureFunction::All => AsBoolean, + ClosureFunction::Some => AsBoolean, + ClosureFunction::None => AsBoolean, + ClosureFunction::One => AsBoolean, + ClosureFunction::Filter => CompareWithReference(In), + ClosureFunction::Map => CompareWithReference(In), + ClosureFunction::FlatMap => CompareWithReference(In), + ClosureFunction::Count => CompareWithReference(Equal), + }, + FunctionKind::Mf(_) => CompareWithReference(Equal), }, Node::MethodCall { kind, .. } => match kind { MethodKind::DateMethod(dm) => match dm {