Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 91 additions & 27 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -21227,7 +21227,15 @@
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 40,
"endColumn": 30,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 49,
"lineCount": 1
}
},
Expand Down Expand Up @@ -21477,23 +21485,31 @@
"code": "reportUnknownMemberType",
"range": {
"startColumn": 20,
"endColumn": 43,
"endColumn": 33,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 20,
"endColumn": 52,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 20,
"endColumn": 73,
"endColumn": 82,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 60,
"endColumn": 72,
"startColumn": 69,
"endColumn": 81,
"lineCount": 1
}
},
Expand Down Expand Up @@ -53877,6 +53893,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 22,
"endColumn": 33,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand All @@ -53896,8 +53920,8 @@
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 42,
"endColumn": 46,
"startColumn": 37,
"endColumn": 47,
"lineCount": 1
}
},
Expand All @@ -53921,15 +53945,15 @@
"code": "reportUnknownMemberType",
"range": {
"startColumn": 19,
"endColumn": 31,
"endColumn": 38,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 19,
"endColumn": 31,
"endColumn": 38,
"lineCount": 1
}
},
Expand All @@ -53949,6 +53973,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 57,
"endColumn": 62,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down Expand Up @@ -53997,6 +54029,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 37,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand All @@ -54013,6 +54053,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 37,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand All @@ -54033,15 +54081,15 @@
"code": "reportUnknownMemberType",
"range": {
"startColumn": 19,
"endColumn": 34,
"endColumn": 41,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 29,
"endColumn": 44,
"endColumn": 51,
"lineCount": 1
}
},
Expand All @@ -54053,6 +54101,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 51,
"endColumn": 57,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand All @@ -54061,6 +54117,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 51,
"endColumn": 57,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down Expand Up @@ -54093,6 +54157,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 51,
"endColumn": 57,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down Expand Up @@ -90429,6 +90501,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 13,
"endColumn": 26,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down Expand Up @@ -101881,22 +101961,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 36,
"endColumn": 48,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 36,
"endColumn": 51,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
4 changes: 2 additions & 2 deletions loopy/isl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def simplify_pw_aff(pw_aff, context=None):
if i == j:
continue

if aff_i.gist(dom_j).is_equal(aff_j):
if aff_i.gist(dom_j).to_pw_aff().is_equal(aff_j):
# aff_i is sufficient to cover aff_j, eliminate aff_j
new_pieces = pieces[:]
if i < j:
Expand Down Expand Up @@ -895,7 +895,7 @@ def find_and_rename_dim(isl_obj, dt, old_name, new_name):

"""
return isl_obj.set_dim_name(
dt, isl_obj.find_dim_by_name(dt, old_name), new_name)
dt, isl_obj.to_set().find_dim_by_name(dt, old_name), new_name)

# }}}

Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def combine_domains(self, domains: Sequence[int]) -> isl.BasicSet:
dim_type.set,
result.dim(dim_type.set),
dim_type.param,
result.find_dim_by_name(dim_type.param, actual_iname),
result.to_set().find_dim_by_name(dim_type.param, actual_iname),
1)

return result
Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _normalize_string_tag(tag):
if tag == "!streaming_store":
return UseStreamingStoreTag()
else:
from pytools import resolve_name
from pkgutil import resolve_name
try:
tag_cls = resolve_name(tag)
except ImportError:
Expand Down
15 changes: 9 additions & 6 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,19 +1445,22 @@ def count(kernel, set, space=None):
.drop_dims(dim_type.set, 0, set.dim(dim_type.set))
.add_dims(dim_type.set, 1))

if isinstance(set, isl.BasicSet):
set = set.to_set()
set = set.make_disjoint()

from loopy.isl_helpers import get_simple_strides

for bset in set.get_basic_sets():
bset_as_set = bset.to_set()
bset_count = None
bset_rebuilt = bset.universe(bset.space)

bset_strides = get_simple_strides(bset, key_by="index")

for i in range(bset.dim(isl.dim_type.set)):
dmax = bset.dim_max(i)
dmin = bset.dim_min(i)
dmax = bset_as_set.dim_max(i)
dmin = bset_as_set.dim_min(i)

stride = bset_strides.get((dim_type.set, i))
if stride is None:
Expand All @@ -1483,8 +1486,8 @@ def count(kernel, set, space=None):
dmax_matched = dmax.insert_dims(
dim_type.in_, 0, bset.dim(isl.dim_type.set))
for idx in range(bset.dim(isl.dim_type.set)):
if bset.has_dim_id(isl.dim_type.set, idx):
dim_id = bset.get_dim_id(isl.dim_type.set, idx)
if bset_as_set.has_dim_id(isl.dim_type.set, idx):
dim_id = bset_as_set.get_dim_id(isl.dim_type.set, idx)
dmin_matched = dmin_matched.set_dim_id(
isl.dim_type.in_, idx, dim_id)
dmax_matched = dmax_matched.set_dim_id(
Expand All @@ -1501,8 +1504,8 @@ def count(kernel, set, space=None):
if bset_count is not None:
total_count += bset_count

is_subset = bset <= bset_rebuilt
is_superset = bset >= bset_rebuilt
is_subset = bset_as_set <= bset_rebuilt
is_superset = bset_as_set >= bset_rebuilt

if not (is_subset and is_superset):
if is_subset:
Expand Down
4 changes: 2 additions & 2 deletions loopy/transform/iname.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,11 @@ def make_new_loop_index(
dom
.eliminate(dt, idx, 1)
& aff_zero.le_set(aff_split_iname)
& aff_split_iname.lt_set(aligned_size)
& aff_split_iname.to_pw_aff().lt_set(aligned_size)
)

if not (
box_dom <= dom <= box_dom):
box_dom <= dom.to_set() <= box_dom):
raise LoopyError(
f"domain '{dom}' is not box-shape about iname "
f"'{split_iname}', cannot use chunk_iname()")
Expand Down
6 changes: 4 additions & 2 deletions loopy/transform/realize_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _try_infer_scan_and_sweep_bounds(

domain = domain.project_out_except(
{*within_inames, *kernel.non_iname_variable_names()},
(isl.dim_type.param,))
(isl.dim_type.param,)).to_set()

try:
sweep_lower_bound = domain.dim_min(sweep_idx)
Expand All @@ -535,7 +535,8 @@ def _try_infer_scan_stride(kernel, scan_iname, sweep_iname, sweep_lower_bound):
domain_with_sweep_param = _move_set_to_param_dims_except(domain, (scan_iname,))

domain_with_sweep_param = domain_with_sweep_param.project_out_except(
(sweep_iname, scan_iname), (dim_type.set, dim_type.param))
(sweep_iname, scan_iname), (dim_type.set, dim_type.param)
).to_set()

scan_iname_idx = domain_with_sweep_param.find_dim_by_name(
dim_type.set, scan_iname)
Expand Down Expand Up @@ -600,6 +601,7 @@ def _try_infer_scan_stride(kernel, scan_iname, sweep_iname, sweep_lower_bound):

def _get_domain_with_iname_as_param(domain, iname):
dim_type = isl.dim_type
domain = domain.to_set()

if domain.find_dim_by_name(dim_type.param, iname) >= 0:
return domain
Expand Down
2 changes: 1 addition & 1 deletion test/test_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_multi_arg_array_call(ctx_factory: cl.CtxFactory):
lp.Assignment(id="init1", assignee=acc_i,
expression="214748367"),
lp.Assignment(id="insn", assignee=index,
expression=p.If(p.Expression.eq(acc_i, a_i), i, index),
expression=p.If(p.ExpressionNode.eq(acc_i, a_i), i, index),
depends_on="update"),
lp.Assignment(id="update", assignee=acc_i,
expression=p.Variable("min")(acc_i, a_i),
Expand Down
Loading