diff --git a/spectec/src/exe-spectec/main.ml b/spectec/src/exe-spectec/main.ml index ae86b09d9e..3a52c7a05a 100644 --- a/spectec/src/exe-spectec/main.ml +++ b/spectec/src/exe-spectec/main.ml @@ -27,6 +27,7 @@ type pass = | AliasDemut | ImproveIds | Ite + | PatSimp (* This list declares the intended order of passes. @@ -44,6 +45,7 @@ let all_passes = [ Uncaseremoval; Sideconditions; SubExpansion; + PatSimp; Sub; AliasDemut; ImproveIds @@ -112,6 +114,7 @@ let pass_flag = function | Uncaseremoval -> "uncase-removal" | ImproveIds -> "improve-ids" | Ite -> "ite" + | PatSimp -> "pattern-simp" let pass_desc = function | Sub -> "Synthesize explicit subtype coercions" @@ -126,6 +129,7 @@ let pass_desc = function | AliasDemut -> "Lifts type aliases out of mutual groups" | ImproveIds -> "Disambiguates ids used from each other" | Ite -> "If-then-else introduction" + | PatSimp -> "Simplifies non-linear and definite iteration patterns" let run_pass : pass -> Il.Ast.script -> Il.Ast.script = function @@ -141,7 +145,7 @@ let run_pass : pass -> Il.Ast.script -> Il.Ast.script = function | AliasDemut -> Middlend.AliasDemut.transform | ImproveIds -> Middlend.Improveids.transform | Ite -> Middlend.Ite.transform - + | PatSimp -> Middlend.PatSimp.transform (* Argument parsing *) diff --git a/spectec/src/il/walk.ml b/spectec/src/il/walk.ml index d8d1d3e014..4a471d0946 100644 --- a/spectec/src/il/walk.ml +++ b/spectec/src/il/walk.ml @@ -35,7 +35,8 @@ type transformer = { transform_def_id: id -> id; transform_gram_id: id -> id; - filter_exp : exp -> exp option + (* Adjusting traversal *) + transform_types_of_exp : bool } let id = Fun.id @@ -55,7 +56,7 @@ let base_transformer = { transform_def_id = id; transform_gram_id = id; - filter_exp = op_id + transform_types_of_exp = true } let rec transform_typ t typ = @@ -72,7 +73,6 @@ let rec transform_typ t typ = and transform_exp t e = let f = t.transform_exp in - let g = t.filter_exp in let t_exp = transform_exp t in let it = match e.it with @@ -107,13 +107,8 @@ and transform_exp t e = | SubE (e1, _t1, t2) -> SubE (t_exp e1, _t1, t2) | IfE (e1, e2, e3) -> IfE (t_exp e1, t_exp e2, t_exp e3) in - - let e' = - match g {e with it; note = transform_typ t e.note } with - | Some e' -> f e' - | None -> e - in - f e' + let typ' = if t.transform_types_of_exp then transform_typ t e.note else e.note in + f { e with it; note = typ' } and transform_iter t iter = match iter with diff --git a/spectec/src/middlend/dune b/spectec/src/middlend/dune index c2bc064072..31f3bb2553 100644 --- a/spectec/src/middlend/dune +++ b/spectec/src/middlend/dune @@ -15,5 +15,6 @@ aliasDemut improveids ite + patSimp ) ) diff --git a/spectec/src/middlend/patSimp.ml b/spectec/src/middlend/patSimp.ml new file mode 100644 index 0000000000..cf9d0066a5 --- /dev/null +++ b/spectec/src/middlend/patSimp.ml @@ -0,0 +1,173 @@ +(* + This pass simplifies definite iteration and non-linear patterns + by utilizing premises. + + It achieves this through the following steps: + - For non-linear patterns: + * For each clause, we traverse the arguments and keep track + of all variables in expressions. If a variable appears more + than once, we generate a fresh version of the variable and + keep it for later. + * Once we have traversed the entire argument list, we use + the variables tracked to generate new quantifiers and equality + premises. + - For definite iteration: + * For each clause, we traverse the arguments, and collect + all variables used for definite iteration (i.e. the e in ListN e _ ) + and the respective lists being iterated. + * Using the collected variables, we iterate through the list to create + the equality premises. + + + For example (for non-linear pattern), take the function: + + def $find(nat, nat* ) : bool + def $find(n, eps ) = false + def $find(n, n n'* ) = true + def $find(n, n_1 n'* ) = $find( n, n'* ) + Would be transformed as such: + + def $find(nat : nat, nat* ) : bool + def $find{n : nat}(n, []) = false + def $find{n : nat, `n'*` : nat*, n#1 : nat}(n, [n#1] ++ n'*{n' <- `n'*`}) = true + -- if (n = n#1) + def $find{n : nat, n_1 : nat, `n'*` : nat*}(n, [n_1] ++ n'*{n' <- `n'*`}) = $find(n, n'*{n' <- `n'*`}) + + For definite iteration: + + def $len( int* ) : nat + def $len(i^n) = n + + to + + def $len(int* ) : nat + def $len{n : nat, `i*` : int*}(i*{i <- `i*`}) = n + -- if (n = |`i*`|) + +NOTE: Currently does not work with dependent types. As such, it depends on the type family removal pass and undep pass. +This is a todo for the future. +*) + +open Il.Ast +open Il.Walk +open Util +open Source + +module StringMap = Map.Make(String) + +let (let*) = Option.bind + +let create_eq_prem id typ id' = + let idexp = VarE id $$ id.at % typ in + let idexp' = VarE (id' $ id.at) $$ id.at % typ in + let exp = CmpE (`EqOp, `BoolT, idexp, idexp') $$ id.at % (BoolT $ id.at) in + IfPr exp $ id.at + +let create_eq_prem_exp e e' = + let exp = CmpE (`EqOp, `BoolT, e, e') $$ e.at % (BoolT $ e.at) in + IfPr exp $ e.at + +let create_iter_prem scope base_prem = + List.fold_left (fun prem iterexp -> + IterPr (prem, iterexp) $ prem.at + ) base_prem scope + +let t_exp varmap exp = + match exp.it with + | VarE id -> + let fresh_var = ref id.it in + varmap := StringMap.update id.it (fun opt -> + match opt with + | Some lst -> + fresh_var := Utils.generate_var (List.map fst (StringMap.bindings !varmap) @ lst) id.it; + Some (!fresh_var :: lst) + | None -> Some [] + ) !varmap; + { exp with it = VarE (!fresh_var $ id.at) } + | _ -> exp + +let rec c_exp scope exp = + match exp.it with + | IterE (_, (ListN (e'', _), eps)) -> ([e'', eps, scope], true) + | IterE (e, (iter, eps)) -> + let lst_cl = base_collector [] (@) in + let cl = { lst_cl with collect_exp = c_exp ((iter, eps) :: scope) } in + let def_lst = collect_exp cl e in + (def_lst, false) + | _ -> ([], true) + +let t_exp2 exp = + match exp.it with + | IterE (e, (ListN _, eps)) -> { exp with it = IterE (e, (List, eps)) } + | _ -> exp + +let t_typ2 typ = + match typ.it with + | IterT (t, ListN _) -> { typ with it = IterT (t, List) } + | _ -> typ + +let handle_non_linear clause = + let DefD (quants, args, exp, prs) = clause.it in + let varmap = ref StringMap.empty in + let tf = { base_transformer with transform_exp = t_exp varmap; transform_types_of_exp = false } in + let args' = List.map (transform_arg tf) args in + let new_quants, new_prs = List.filter_map (fun q -> match q.it with + | ExpP (id, typ) -> + let* ts = StringMap.find_opt id.it !varmap in + if ts = [] then None else + let q' = List.map (fun id' -> ExpP (id' $ id.at, typ) $ id.at) ts in + let prs'= List.map (create_eq_prem id typ) ts in + Some (q', prs') + | _ -> None + ) quants |> List.split + in + + { clause with it = DefD (quants @ (List.concat new_quants), args', exp, prs @ (List.concat new_prs)) } + +let handle_definite_iter clause = + let DefD (quants, args, exp, prs) = clause.it in + let lst_cl = base_collector [] (@) in + let cl = { lst_cl with collect_exp = c_exp [] } in + let tf = { base_transformer with transform_exp = t_exp2; transform_typ = t_typ2 } in + + let def_lst = List.concat_map (collect_arg cl) args in + let new_prs = List.concat_map (fun (n, eps, scope) -> + let lene e = LenE e $$ e.at % (NumT `NatT $ e.at) in + List.map (fun (_, e) -> + let base_prem = create_eq_prem_exp n (lene e) in + create_iter_prem scope base_prem) eps + ) def_lst + in + + { clause with it = DefD (quants, List.map (transform_arg tf) args, exp, prs @ new_prs) } + +let handle_definite_iter_rel rule = + let RuleD (id, quants, mixop, exp, prs) = rule.it in + let lst_cl = base_collector [] (@) in + let cl = { lst_cl with collect_exp = c_exp [] } in + + let def_lst = collect_exp cl exp @ List.concat_map (collect_prem cl) prs in + let def_lst_uniq = Lib.List.nub (fun (n1, eps1, _) (n2, eps2, _) -> + Il.Eq.eq_exp n1 n2 && List.length eps1 = List.length eps2 && + List.for_all2 (fun (id1, e1) (id2, e2) -> id1.it = id2.it && Il.Eq.eq_exp e1 e2) eps1 eps2 + ) def_lst in + let new_prs = List.concat_map (fun (n, eps, scope) -> + let lene e = LenE e $$ e.at % (NumT `NatT $ e.at) in + List.map (fun (_, e) -> + let base_prem = create_eq_prem_exp n (lene e) in + create_iter_prem scope base_prem) eps + ) def_lst_uniq + in + + { rule with it = RuleD (id, quants, mixop, exp, prs @ new_prs) } + +let rec t_def def = + match def.it with + | DecD (id, params, rt, clauses) -> { def with it = DecD (id, params, rt, + clauses |> List.map handle_definite_iter |> List.map (handle_non_linear)) } + | RelD (id, qs, mixop, typ, rules) -> + { def with it = RelD (id, qs, mixop, typ, List.map handle_definite_iter_rel rules) } + | RecD defs -> { def with it = RecD (List.map t_def defs) } + | _ -> def + +let transform il = List.map t_def il diff --git a/spectec/src/middlend/patSimp.mli b/spectec/src/middlend/patSimp.mli new file mode 100644 index 0000000000..64d020ff9d --- /dev/null +++ b/spectec/src/middlend/patSimp.mli @@ -0,0 +1 @@ +val transform : Il.Ast.script -> Il.Ast.script \ No newline at end of file diff --git a/spectec/test-middlend/test.spectec b/spectec/test-middlend/test.spectec index 933b224916..82d8a46852 100644 --- a/spectec/test-middlend/test.spectec +++ b/spectec/test-middlend/test.spectec @@ -33,4 +33,27 @@ def $t_totalize3(n) = $($t_totalize(n) + $t_totalize2($(n + 10))) ;; def $t_totalize3(n) = $($t_totalize(n) + $t_totalize2($t_totalize(n))) ;; ;; def $t_totalize4(nat) : nat hint(partial) -;; def $t_totalize4(n) = $t_totalize($t_totalize(n)) \ No newline at end of file +;; def $t_totalize4(n) = $t_totalize($t_totalize(n)) + +;; +;; Pattern Simp testing +;; + +syntax A = B nat + +def $t_patsimp(nat, nat) : nat +def $t_patsimp(n, n) = n +def $t_patsimp(n, m) = $(n + m) + +def $t_patsimp2(nat, nat, A) : nat +def $t_patsimp2(n, n, B n) = n +def $t_patsimp2(n, m, B m) = $(n + m) +def $t_patsimp2(n, m, B k) = $(n + m + k) + +def $find(nat, nat*) : bool +def $find(n, eps) = false +def $find(n, n n'*) = true +def $find(n, n_1 n'*) = $find(n, n'*) + +def $len(int*) : nat +def $len(i^n) = n \ No newline at end of file