diff --git a/eggplant-transpiler/src/eggplant.rs b/eggplant-transpiler/src/eggplant.rs index c861626..647e341 100644 --- a/eggplant-transpiler/src/eggplant.rs +++ b/eggplant-transpiler/src/eggplant.rs @@ -466,26 +466,34 @@ impl EggplantCodeGenerator { self.add_line("}"); } EggplantCommand::RunSchedule { schedules } => { - let schedule_program = schedules - .iter() - .map(ToString::to_string) - .collect::>() - .join(" "); - let command = format!("(run-schedule {schedule_program})"); - self.add_line("let outputs = {"); - self.indent(); - self.add_line("let mut egraph = MyTx::sgl().egraph.lock().unwrap();"); - self.add_line(&format!( - "egraph.parse_and_run_program(None, {:?}).unwrap()", - command - )); - self.dedent(); - self.add_line("};"); - self.add_line("for output in outputs {"); - self.indent(); - self.add_line("print!(\"{}\", output);"); - self.dedent(); - self.add_line("}"); + if schedules.iter().any(schedule_has_until) { + let schedule_program = schedules + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + let command = format!("(run-schedule {schedule_program})"); + self.add_line("let outputs = {"); + self.indent(); + self.add_line("let mut egraph = MyTx::sgl().egraph.lock().unwrap();"); + self.add_line(&format!( + "egraph.parse_and_run_program(None, {:?}).unwrap()", + command + )); + self.dedent(); + self.add_line("};"); + self.add_line("for output in outputs {"); + self.indent(); + self.add_line("print!(\"{}\", output);"); + self.dedent(); + self.add_line("}"); + } else { + self.add_line(&format!( + "let schedule = {};", + schedule_items_to_rust_expr(schedules) + )); + self.add_line("let _report = MyTx::run_schedule(schedule);"); + } } EggplantCommand::Assert { expr, expected } => { self.add_line(&format!( @@ -1333,6 +1341,99 @@ fn same_normalized_identifier(left: &str, right: &str) -> bool { normalize_identifier(left) == normalize_identifier(right) } +fn schedule_has_until(schedule: &Schedule) -> bool { + match schedule { + Schedule::Run { until, .. } => until.is_some(), + Schedule::Named(_) => false, + Schedule::Seq(items) | Schedule::Saturate(items) => items.iter().any(schedule_has_until), + Schedule::Repeat(_, inner) => schedule_has_until(inner), + } +} + +fn schedule_items_to_rust_expr(schedules: &[Schedule]) -> String { + match schedules { + [] => "RunSchedule::builder().build()".to_string(), + [schedule] => schedule_to_rust_expr(schedule), + _ => { + let mut expr = "RunSchedule::builder()".to_string(); + for schedule in schedules { + expr.push_str(&format!(".then({})", schedule_to_rust_expr(schedule))); + } + expr.push_str(".build()"); + expr + } + } +} + +fn schedule_to_rust_expr(schedule: &Schedule) -> String { + match schedule { + Schedule::Run { + ruleset, + limit, + until: None, + } => { + let ruleset = schedule_ruleset_expr(ruleset.as_deref()); + let run = format!("RunSchedule::builder().run({ruleset}).build()"); + match limit { + Some(limit) => { + format!( + "RunSchedule::builder().repeat({limit}, |schedule| schedule.then({run})).build()" + ) + } + None => run, + } + } + Schedule::Run { .. } => { + unreachable!("run-schedule with :until must use raw egglog bridge") + } + Schedule::Named(name) => { + format!( + "RunSchedule::builder().run({}).build()", + normalize_ruleset_name(name) + ) + } + Schedule::Seq(items) => schedule_items_to_rust_expr(items), + Schedule::Saturate(items) => match items.as_slice() { + [single] => { + if let Some(ruleset) = direct_schedule_ruleset_expr(single) { + format!("RunSchedule::builder().saturate({ruleset}).build()") + } else { + format!( + "RunSchedule::builder().saturate_schedule({}).build()", + schedule_to_rust_expr(single) + ) + } + } + _ => format!( + "RunSchedule::builder().saturate_schedule({}).build()", + schedule_items_to_rust_expr(items) + ), + }, + Schedule::Repeat(times, inner) => { + format!( + "RunSchedule::builder().repeat({times}, |schedule| schedule.then({})).build()", + schedule_to_rust_expr(inner) + ) + } + } +} + +fn direct_schedule_ruleset_expr(schedule: &Schedule) -> Option { + match schedule { + Schedule::Named(name) => Some(normalize_ruleset_name(name)), + Schedule::Run { + ruleset, + limit: None, + until: None, + } => Some(schedule_ruleset_expr(ruleset.as_deref())), + _ => None, + } +} + +fn schedule_ruleset_expr(ruleset: Option<&str>) -> String { + normalize_ruleset_name(ruleset.unwrap_or("default")) +} + fn expr_type_name(expr: &Expr) -> String { match expr { Expr::Var(_, name) => normalize_identifier(name), @@ -4532,6 +4633,43 @@ mod tests { assert!(rust.contains("for output in outputs {")); } + #[test] + fn test_run_schedule_without_until_uses_builder_codegen() { + let program = r#" + (ruleset fast-analyses) + (ruleset subst) + (run-schedule + (repeat 2 + (saturate fast-analyses) + (run) + (saturate subst))) + "#; + + let mut parser = Parser::default(); + let commands = parser.get_program_from_string(None, program).unwrap(); + let rust = EggplantCodeGenerator::new() + .generate_rust(&convert_to_eggplant_with_source(&commands, None)); + + assert!( + rust.contains("RunSchedule::builder()"), + "generated Rust:\n{rust}" + ); + assert!(rust.contains(".repeat(2"), "generated Rust:\n{rust}"); + assert!( + rust.contains(".saturate(fast_analyses)"), + "generated Rust:\n{rust}" + ); + assert!( + rust.contains(".run(default_ruleset)"), + "generated Rust:\n{rust}" + ); + assert!(rust.contains(".saturate(subst)"), "generated Rust:\n{rust}"); + assert!( + !rust.contains("parse_and_run_program"), + "generated Rust:\n{rust}" + ); + } + #[test] fn test_include_conversion_and_codegen_uses_resolved_path() { let source_path = upstream_egglog_fixture("tests/include.egg"); diff --git a/src/prelude.rs b/src/prelude.rs index e4ab62a..f646f7d 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -30,10 +30,10 @@ pub use crate::wrap::sorts::vec::VecContainer; pub use crate::wrap::{ AsHandle, BaseVar, Commit, EBoostExtractConfig, EBoostLayeredConfig, EgglogNode, ExtractBackend, ExtractNodeSgl, ExtractSgl, FromBase, Insertable, LocateVersion, PEq, - PatRecSgl, QuerySlot, RenderedTemplateField, RuleRunnerSgl, RuleSetId, RunConfig, - RustsatExtractConfig, RxSgl, SingletonGetter, SlotVarID, SlottedPatRecSgl, ToDot, ToDotSgl, - TxCommit, TxCommitSgl, TxSgl, Value, extract_raw_with_backend, render_template_with_precedence, - render_variant_display, render_variant_typst, + PatRecSgl, QuerySlot, RenderedTemplateField, RuleRunnerSgl, RuleSetId, RunConfig, RunSchedule, + RunScheduleBuilder, RustsatExtractConfig, RxSgl, SingletonGetter, SlotVarID, SlottedPatRecSgl, + ToDot, ToDotSgl, TxCommit, TxCommitSgl, TxSgl, Value, extract_raw_with_backend, + render_template_with_precedence, render_variant_display, render_variant_typst, }; pub use dashmap; diff --git a/src/wrap/rule.rs b/src/wrap/rule.rs index 75a2695..2ff2da0 100644 --- a/src/wrap/rule.rs +++ b/src/wrap/rule.rs @@ -590,6 +590,11 @@ pub trait RuleRunner { ); fn new_ruleset(&self, rule_set: &'static str) -> RuleSetId; fn run_ruleset(&self, rule_set_id: RuleSetId, run_config: RunConfig) -> RunReport; + fn run_schedule(&self, schedule: impl Into) -> RunReport { + schedule + .into() + .execute_with(|ruleset| self.run_ruleset(ruleset, RunConfig::Once)) + } fn value(&self, node: &T) -> Value; } pub trait RuleRunnerSgl: WithPatRecSgl + NodeDropperSgl { @@ -619,6 +624,11 @@ pub trait RuleRunnerSgl: WithPatRecSgl + NodeDropperSgl { ); fn new_ruleset(rule_set: &'static str) -> RuleSetId; fn run_ruleset(rule_set_id: RuleSetId, run_config: RunConfig) -> RunReport; + fn run_schedule(schedule: impl Into) -> RunReport { + schedule + .into() + .execute_with(|ruleset| Self::run_ruleset(ruleset, RunConfig::Once)) + } fn value(node: &T) -> Value; } impl RuleRunnerSgl for T @@ -646,15 +656,166 @@ where } } -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RuleSetId(pub &'static str); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RunConfig { Sat, Times(u32), Once, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RunSchedule { + Run(RuleSetId), + Sequence(Vec), + Repeat { + times: u32, + schedule: Box, + }, + Saturate(Box), +} + +impl RunSchedule { + pub fn builder() -> RunScheduleBuilder { + RunScheduleBuilder::new() + } + + pub fn run(ruleset: RuleSetId) -> Self { + Self::Run(ruleset) + } + + pub fn from_ruleset_config(ruleset: RuleSetId, config: RunConfig) -> Self { + match config { + RunConfig::Once => Self::run(ruleset), + RunConfig::Times(times) => Self::repeat(times, Self::run(ruleset)), + RunConfig::Sat => Self::saturate(Self::run(ruleset)), + } + } + + pub fn seq(items: impl IntoIterator) -> Self { + Self::Sequence(items.into_iter().collect()) + } + + pub fn repeat(times: u32, schedule: impl Into) -> Self { + Self::Repeat { + times, + schedule: Box::new(schedule.into()), + } + } + + pub fn saturate(schedule: impl Into) -> Self { + Self::Saturate(Box::new(schedule.into())) + } + + pub fn execute_with(&self, mut run_once: impl FnMut(RuleSetId) -> RunReport) -> RunReport { + self.execute_with_ref(&mut run_once) + } + + fn execute_with_ref(&self, run_once: &mut impl FnMut(RuleSetId) -> RunReport) -> RunReport { + match self { + RunSchedule::Run(ruleset) => run_once(*ruleset), + RunSchedule::Sequence(schedules) => { + let mut report = RunReport::default(); + for schedule in schedules { + report.union(schedule.execute_with_ref(run_once)); + } + report + } + RunSchedule::Repeat { times, schedule } => { + let mut report = RunReport::default(); + for _ in 0..*times { + let iter_report = schedule.execute_with_ref(run_once); + let updated = iter_report.updated; + report.union(iter_report); + if !updated { + break; + } + } + report + } + RunSchedule::Saturate(schedule) => { + let mut report = RunReport::default(); + loop { + let iter_report = schedule.execute_with_ref(run_once); + let updated = iter_report.updated; + report.union(iter_report); + if !updated { + break report; + } + } + } + } + } +} + +impl From for RunSchedule { + fn from(ruleset: RuleSetId) -> Self { + Self::run(ruleset) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct RunScheduleBuilder { + schedules: Vec, +} + +impl RunScheduleBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn run(mut self, ruleset: RuleSetId) -> Self { + self.schedules.push(RunSchedule::run(ruleset)); + self + } + + pub fn run_with(mut self, ruleset: RuleSetId, config: RunConfig) -> Self { + self.schedules + .push(RunSchedule::from_ruleset_config(ruleset, config)); + self + } + + pub fn then(mut self, schedule: impl Into) -> Self { + self.schedules.push(schedule.into()); + self + } + + pub fn repeat( + mut self, + times: u32, + build: impl FnOnce(RunScheduleBuilder) -> RunScheduleBuilder, + ) -> Self { + self.schedules + .push(RunSchedule::repeat(times, build(Self::new()).build())); + self + } + + pub fn repeat_schedule(mut self, times: u32, schedule: impl Into) -> Self { + self.schedules.push(RunSchedule::repeat(times, schedule)); + self + } + + pub fn saturate(mut self, ruleset: RuleSetId) -> Self { + self.schedules.push(RunSchedule::saturate(ruleset)); + self + } + + pub fn saturate_schedule(mut self, schedule: impl Into) -> Self { + self.schedules.push(RunSchedule::saturate(schedule)); + self + } + + pub fn build(self) -> RunSchedule { + match self.schedules.len() { + 0 => RunSchedule::Sequence(Vec::new()), + 1 => self.schedules.into_iter().next().unwrap(), + _ => RunSchedule::Sequence(self.schedules), + } + } +} + pub struct FactsBuilder { table_facts: Vec, constraint_facts: Vec>, diff --git a/tests/schedule_builder.rs b/tests/schedule_builder.rs new file mode 100644 index 0000000..d29cbaf --- /dev/null +++ b/tests/schedule_builder.rs @@ -0,0 +1,71 @@ +#![allow(non_camel_case_types)] + +use eggplant::{prelude::*, tx_rx_vt_pr}; + +#[eggplant::func(output = i64, no_merge)] +struct sched_fib { + x: i64, +} + +#[eggplant::pat_vars] +struct FibStep { + x: i64, + x1: i64, + x2: i64, + f0: i64, + f1: i64, +} + +tx_rx_vt_pr!(ScheduleBuilderTx, ScheduleBuilderPatRec); + +#[test] +fn schedule_builder_runs_nested_repeat_schedule() { + let seed = ScheduleBuilderTx::new_ruleset("schedule_builder_seed"); + ScheduleBuilderTx::add_rule( + "schedule_builder_seed", + seed, + || { + #[eggplant::pat_vars_catch] + struct Unit {} + }, + |ctx, _pat| { + ctx.set_sched_fib(0, 0); + ctx.set_sched_fib(1, 1); + }, + ); + + let step = ScheduleBuilderTx::new_ruleset("schedule_builder_step"); + ScheduleBuilderTx::add_rule( + "schedule_builder_step", + step, + || { + let (x, x1, x2) = ( + sched_fib::x(), + sched_fib::x().named("x1"), + sched_fib::x().named("x2"), + ); + let x1_constraint = x1.handle().eq(&(x.handle() + (&1_i64).as_handle())); + let x2_constraint = x2.handle().eq(&(x.handle() + (&2_i64).as_handle())); + let f0 = sched_fib::query(&x); + let f1 = sched_fib::query(&x1); + FibStep::new(x, x1, x2, f0, f1) + .assert(x1_constraint) + .assert(x2_constraint) + }, + |ctx, pat| { + let x2 = ctx.devalue(pat.x2); + let f0 = ctx.devalue(pat.f0); + let f1 = ctx.devalue(pat.f1); + ctx.set_sched_fib(x2, f0 + f1); + }, + ); + + let schedule = RunSchedule::builder() + .run(seed) + .repeat(7, |schedule| schedule.run(step)) + .build(); + + ScheduleBuilderTx::run_schedule(schedule); + + assert_eq!(sched_fib::::get(&7), 13); +}