Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion spectec/src/exe-spectec/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type pass =
| AliasDemut
| ImproveIds
| Ite
| PatSimp

(* This list declares the intended order of passes.

Expand All @@ -44,6 +45,7 @@ let all_passes = [
Uncaseremoval;
Sideconditions;
SubExpansion;
PatSimp;
Sub;
AliasDemut;
ImproveIds
Expand Down Expand Up @@ -112,6 +114,7 @@ let pass_flag = function
| Uncaseremoval -> "uncase-removal"
| ImproveIds -> "improve-ids"
| Ite -> "ite"
| PatSimp -> "pattern-simp"

let pass_desc = function
| Sub -> "Synthesize explicit subtype coercions"
Expand All @@ -126,6 +129,7 @@ let pass_desc = function
| AliasDemut -> "Lifts type aliases out of mutual groups"
| ImproveIds -> "Disambiguates ids used from each other"
| Ite -> "If-then-else introduction"
| PatSimp -> "Simplifies non-linear and definite iteration patterns"


let run_pass : pass -> Il.Ast.script -> Il.Ast.script = function
Expand All @@ -141,7 +145,7 @@ let run_pass : pass -> Il.Ast.script -> Il.Ast.script = function
| AliasDemut -> Middlend.AliasDemut.transform
| ImproveIds -> Middlend.Improveids.transform
| Ite -> Middlend.Ite.transform

| PatSimp -> Middlend.PatSimp.transform

(* Argument parsing *)

Expand Down
15 changes: 5 additions & 10 deletions spectec/src/il/walk.ml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ type transformer = {
transform_def_id: id -> id;
transform_gram_id: id -> id;

filter_exp : exp -> exp option
(* Adjusting traversal *)
transform_types_of_exp : bool
}

let id = Fun.id
Expand All @@ -55,7 +56,7 @@ let base_transformer = {
transform_def_id = id;
transform_gram_id = id;

filter_exp = op_id
transform_types_of_exp = true
}

let rec transform_typ t typ =
Expand All @@ -72,7 +73,6 @@ let rec transform_typ t typ =

and transform_exp t e =
let f = t.transform_exp in
let g = t.filter_exp in
let t_exp = transform_exp t in
let it =
match e.it with
Expand Down Expand Up @@ -107,13 +107,8 @@ and transform_exp t e =
| SubE (e1, _t1, t2) -> SubE (t_exp e1, _t1, t2)
| IfE (e1, e2, e3) -> IfE (t_exp e1, t_exp e2, t_exp e3)
in

let e' =
match g {e with it; note = transform_typ t e.note } with
| Some e' -> f e'
| None -> e
in
f e'
let typ' = if t.transform_types_of_exp then transform_typ t e.note else e.note in
f { e with it; note = typ' }

and transform_iter t iter =
match iter with
Expand Down
1 change: 1 addition & 0 deletions spectec/src/middlend/dune
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
aliasDemut
improveids
ite
patSimp
)
)
173 changes: 173 additions & 0 deletions spectec/src/middlend/patSimp.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
(*
This pass simplifies definite iteration and non-linear patterns
by utilizing premises.

It achieves this through the following steps:
- For non-linear patterns:
* For each clause, we traverse the arguments and keep track
of all variables in expressions. If a variable appears more
than once, we generate a fresh version of the variable and
keep it for later.
* Once we have traversed the entire argument list, we use
the variables tracked to generate new quantifiers and equality
premises.
- For definite iteration:
* For each clause, we traverse the arguments, and collect
all variables used for definite iteration (i.e. the e in ListN e _ )
and the respective lists being iterated.
* Using the collected variables, we iterate through the list to create
the equality premises.


For example (for non-linear pattern), take the function:

def $find(nat, nat* ) : bool
def $find(n, eps ) = false
def $find(n, n n'* ) = true
def $find(n, n_1 n'* ) = $find( n, n'* )
Would be transformed as such:

def $find(nat : nat, nat* ) : bool
def $find{n : nat}(n, []) = false
def $find{n : nat, `n'*` : nat*, n#1 : nat}(n, [n#1] ++ n'*{n' <- `n'*`}) = true
-- if (n = n#1)
def $find{n : nat, n_1 : nat, `n'*` : nat*}(n, [n_1] ++ n'*{n' <- `n'*`}) = $find(n, n'*{n' <- `n'*`})

For definite iteration:

def $len( int* ) : nat
def $len(i^n) = n

to

def $len(int* ) : nat
def $len{n : nat, `i*` : int*}(i*{i <- `i*`}) = n
-- if (n = |`i*`|)

NOTE: Currently does not work with dependent types. As such, it depends on the type family removal pass and undep pass.
This is a todo for the future.
*)

open Il.Ast
open Il.Walk
open Util
open Source

module StringMap = Map.Make(String)

let (let*) = Option.bind

let create_eq_prem id typ id' =
let idexp = VarE id $$ id.at % typ in
let idexp' = VarE (id' $ id.at) $$ id.at % typ in
let exp = CmpE (`EqOp, `BoolT, idexp, idexp') $$ id.at % (BoolT $ id.at) in
IfPr exp $ id.at

let create_eq_prem_exp e e' =
let exp = CmpE (`EqOp, `BoolT, e, e') $$ e.at % (BoolT $ e.at) in
IfPr exp $ e.at

let create_iter_prem scope base_prem =
List.fold_left (fun prem iterexp ->
IterPr (prem, iterexp) $ prem.at
) base_prem scope

let t_exp varmap exp =
match exp.it with
| VarE id ->
let fresh_var = ref id.it in
varmap := StringMap.update id.it (fun opt ->
match opt with
| Some lst ->
fresh_var := Utils.generate_var (List.map fst (StringMap.bindings !varmap) @ lst) id.it;
Some (!fresh_var :: lst)
| None -> Some []
) !varmap;
{ exp with it = VarE (!fresh_var $ id.at) }
| _ -> exp

let rec c_exp scope exp =
match exp.it with
| IterE (_, (ListN (e'', _), eps)) -> ([e'', eps, scope], true)
| IterE (e, (iter, eps)) ->
let lst_cl = base_collector [] (@) in
let cl = { lst_cl with collect_exp = c_exp ((iter, eps) :: scope) } in
let def_lst = collect_exp cl e in
(def_lst, false)
| _ -> ([], true)

let t_exp2 exp =
match exp.it with
| IterE (e, (ListN _, eps)) -> { exp with it = IterE (e, (List, eps)) }
| _ -> exp

let t_typ2 typ =
match typ.it with
| IterT (t, ListN _) -> { typ with it = IterT (t, List) }
| _ -> typ

let handle_non_linear clause =
let DefD (quants, args, exp, prs) = clause.it in
let varmap = ref StringMap.empty in
let tf = { base_transformer with transform_exp = t_exp varmap; transform_types_of_exp = false } in
let args' = List.map (transform_arg tf) args in
let new_quants, new_prs = List.filter_map (fun q -> match q.it with
| ExpP (id, typ) ->
let* ts = StringMap.find_opt id.it !varmap in
if ts = [] then None else
let q' = List.map (fun id' -> ExpP (id' $ id.at, typ) $ id.at) ts in
let prs'= List.map (create_eq_prem id typ) ts in
Some (q', prs')
| _ -> None
) quants |> List.split
in

{ clause with it = DefD (quants @ (List.concat new_quants), args', exp, prs @ (List.concat new_prs)) }

let handle_definite_iter clause =
let DefD (quants, args, exp, prs) = clause.it in
let lst_cl = base_collector [] (@) in
let cl = { lst_cl with collect_exp = c_exp [] } in
let tf = { base_transformer with transform_exp = t_exp2; transform_typ = t_typ2 } in

let def_lst = List.concat_map (collect_arg cl) args in
let new_prs = List.concat_map (fun (n, eps, scope) ->
let lene e = LenE e $$ e.at % (NumT `NatT $ e.at) in
List.map (fun (_, e) ->
let base_prem = create_eq_prem_exp n (lene e) in
create_iter_prem scope base_prem) eps
) def_lst
in

{ clause with it = DefD (quants, List.map (transform_arg tf) args, exp, prs @ new_prs) }

let handle_definite_iter_rel rule =
let RuleD (id, quants, mixop, exp, prs) = rule.it in
let lst_cl = base_collector [] (@) in
let cl = { lst_cl with collect_exp = c_exp [] } in

let def_lst = collect_exp cl exp @ List.concat_map (collect_prem cl) prs in
let def_lst_uniq = Lib.List.nub (fun (n1, eps1, _) (n2, eps2, _) ->
Il.Eq.eq_exp n1 n2 && List.length eps1 = List.length eps2 &&
List.for_all2 (fun (id1, e1) (id2, e2) -> id1.it = id2.it && Il.Eq.eq_exp e1 e2) eps1 eps2
) def_lst in
let new_prs = List.concat_map (fun (n, eps, scope) ->
let lene e = LenE e $$ e.at % (NumT `NatT $ e.at) in
List.map (fun (_, e) ->
let base_prem = create_eq_prem_exp n (lene e) in
create_iter_prem scope base_prem) eps
) def_lst_uniq
in

{ rule with it = RuleD (id, quants, mixop, exp, prs @ new_prs) }

let rec t_def def =
match def.it with
| DecD (id, params, rt, clauses) -> { def with it = DecD (id, params, rt,
clauses |> List.map handle_definite_iter |> List.map (handle_non_linear)) }
| RelD (id, qs, mixop, typ, rules) ->
{ def with it = RelD (id, qs, mixop, typ, List.map handle_definite_iter_rel rules) }
| RecD defs -> { def with it = RecD (List.map t_def defs) }
| _ -> def

let transform il = List.map t_def il
1 change: 1 addition & 0 deletions spectec/src/middlend/patSimp.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
val transform : Il.Ast.script -> Il.Ast.script
25 changes: 24 additions & 1 deletion spectec/test-middlend/test.spectec
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,27 @@ def $t_totalize3(n) = $($t_totalize(n) + $t_totalize2($(n + 10)))
;; def $t_totalize3(n) = $($t_totalize(n) + $t_totalize2($t_totalize(n)))
;;
;; def $t_totalize4(nat) : nat hint(partial)
;; def $t_totalize4(n) = $t_totalize($t_totalize(n))
;; def $t_totalize4(n) = $t_totalize($t_totalize(n))

;;
;; Pattern Simp testing
;;

syntax A = B nat

def $t_patsimp(nat, nat) : nat
def $t_patsimp(n, n) = n
def $t_patsimp(n, m) = $(n + m)

def $t_patsimp2(nat, nat, A) : nat
def $t_patsimp2(n, n, B n) = n
def $t_patsimp2(n, m, B m) = $(n + m)
def $t_patsimp2(n, m, B k) = $(n + m + k)

def $find(nat, nat*) : bool
def $find(n, eps) = false
def $find(n, n n'*) = true
def $find(n, n_1 n'*) = $find(n, n'*)

def $len(int*) : nat
def $len(i^n) = n
Loading