diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 87b2aedc8..36926d3d5 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -21227,7 +21227,15 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 16, - "endColumn": 40, + "endColumn": 30, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 16, + "endColumn": 49, "lineCount": 1 } }, @@ -21477,7 +21485,15 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 20, - "endColumn": 43, + "endColumn": 33, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 20, + "endColumn": 52, "lineCount": 1 } }, @@ -21485,15 +21501,15 @@ "code": "reportUnknownArgumentType", "range": { "startColumn": 20, - "endColumn": 73, + "endColumn": 82, "lineCount": 1 } }, { "code": "reportUnknownArgumentType", "range": { - "startColumn": 60, - "endColumn": 72, + "startColumn": 69, + "endColumn": 81, "lineCount": 1 } }, @@ -53877,6 +53893,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 22, + "endColumn": 33, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -53896,8 +53920,8 @@ { "code": "reportUnknownArgumentType", "range": { - "startColumn": 42, - "endColumn": 46, + "startColumn": 37, + "endColumn": 47, "lineCount": 1 } }, @@ -53921,7 +53945,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 19, - "endColumn": 31, + "endColumn": 38, "lineCount": 1 } }, @@ -53929,7 +53953,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 19, - "endColumn": 31, + "endColumn": 38, "lineCount": 1 } }, @@ -53949,6 +53973,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 57, + "endColumn": 62, + "lineCount": 1 + } + }, { "code": "reportUnknownArgumentType", "range": { @@ -53997,6 +54029,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 37, + "endColumn": 63, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -54013,6 +54053,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 37, + "endColumn": 63, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -54033,7 +54081,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 19, - "endColumn": 34, + "endColumn": 41, "lineCount": 1 } }, @@ -54041,7 +54089,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 29, - "endColumn": 44, + "endColumn": 51, "lineCount": 1 } }, @@ -54053,6 +54101,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 51, + "endColumn": 57, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -54061,6 +54117,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 51, + "endColumn": 57, + "lineCount": 1 + } + }, { "code": "reportUnknownArgumentType", "range": { @@ -54093,6 +54157,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 51, + "endColumn": 57, + "lineCount": 1 + } + }, { "code": "reportUnknownArgumentType", "range": { @@ -90429,6 +90501,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 13, + "endColumn": 26, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -101881,22 +101961,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 36, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 36, - "endColumn": 51, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index ec49e2c3f..7c16c0011 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -187,7 +187,7 @@ def simplify_pw_aff(pw_aff, context=None): if i == j: continue - if aff_i.gist(dom_j).is_equal(aff_j): + if aff_i.gist(dom_j).to_pw_aff().is_equal(aff_j): # aff_i is sufficient to cover aff_j, eliminate aff_j new_pieces = pieces[:] if i < j: @@ -895,7 +895,7 @@ def find_and_rename_dim(isl_obj, dt, old_name, new_name): """ return isl_obj.set_dim_name( - dt, isl_obj.find_dim_by_name(dt, old_name), new_name) + dt, isl_obj.to_set().find_dim_by_name(dt, old_name), new_name) # }}} diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 554d17b1e..af51bebbd 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -463,7 +463,7 @@ def combine_domains(self, domains: Sequence[int]) -> isl.BasicSet: dim_type.set, result.dim(dim_type.set), dim_type.param, - result.find_dim_by_name(dim_type.param, actual_iname), + result.to_set().find_dim_by_name(dim_type.param, actual_iname), 1) return result diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index c575367fb..67f2c0fb0 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -125,7 +125,7 @@ def _normalize_string_tag(tag): if tag == "!streaming_store": return UseStreamingStoreTag() else: - from pytools import resolve_name + from pkgutil import resolve_name try: tag_cls = resolve_name(tag) except ImportError: diff --git a/loopy/statistics.py b/loopy/statistics.py index 3fd6568eb..99010c21e 100644 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1445,19 +1445,22 @@ def count(kernel, set, space=None): .drop_dims(dim_type.set, 0, set.dim(dim_type.set)) .add_dims(dim_type.set, 1)) + if isinstance(set, isl.BasicSet): + set = set.to_set() set = set.make_disjoint() from loopy.isl_helpers import get_simple_strides for bset in set.get_basic_sets(): + bset_as_set = bset.to_set() bset_count = None bset_rebuilt = bset.universe(bset.space) bset_strides = get_simple_strides(bset, key_by="index") for i in range(bset.dim(isl.dim_type.set)): - dmax = bset.dim_max(i) - dmin = bset.dim_min(i) + dmax = bset_as_set.dim_max(i) + dmin = bset_as_set.dim_min(i) stride = bset_strides.get((dim_type.set, i)) if stride is None: @@ -1483,8 +1486,8 @@ def count(kernel, set, space=None): dmax_matched = dmax.insert_dims( dim_type.in_, 0, bset.dim(isl.dim_type.set)) for idx in range(bset.dim(isl.dim_type.set)): - if bset.has_dim_id(isl.dim_type.set, idx): - dim_id = bset.get_dim_id(isl.dim_type.set, idx) + if bset_as_set.has_dim_id(isl.dim_type.set, idx): + dim_id = bset_as_set.get_dim_id(isl.dim_type.set, idx) dmin_matched = dmin_matched.set_dim_id( isl.dim_type.in_, idx, dim_id) dmax_matched = dmax_matched.set_dim_id( @@ -1501,8 +1504,8 @@ def count(kernel, set, space=None): if bset_count is not None: total_count += bset_count - is_subset = bset <= bset_rebuilt - is_superset = bset >= bset_rebuilt + is_subset = bset_as_set <= bset_rebuilt + is_superset = bset_as_set >= bset_rebuilt if not (is_subset and is_superset): if is_subset: diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 9d06d7a5a..489af61c9 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -541,11 +541,11 @@ def make_new_loop_index( dom .eliminate(dt, idx, 1) & aff_zero.le_set(aff_split_iname) - & aff_split_iname.lt_set(aligned_size) + & aff_split_iname.to_pw_aff().lt_set(aligned_size) ) if not ( - box_dom <= dom <= box_dom): + box_dom <= dom.to_set() <= box_dom): raise LoopyError( f"domain '{dom}' is not box-shape about iname " f"'{split_iname}', cannot use chunk_iname()") diff --git a/loopy/transform/realize_reduction.py b/loopy/transform/realize_reduction.py index 9e064bc6c..459af6d0a 100644 --- a/loopy/transform/realize_reduction.py +++ b/loopy/transform/realize_reduction.py @@ -511,7 +511,7 @@ def _try_infer_scan_and_sweep_bounds( domain = domain.project_out_except( {*within_inames, *kernel.non_iname_variable_names()}, - (isl.dim_type.param,)) + (isl.dim_type.param,)).to_set() try: sweep_lower_bound = domain.dim_min(sweep_idx) @@ -535,7 +535,8 @@ def _try_infer_scan_stride(kernel, scan_iname, sweep_iname, sweep_lower_bound): domain_with_sweep_param = _move_set_to_param_dims_except(domain, (scan_iname,)) domain_with_sweep_param = domain_with_sweep_param.project_out_except( - (sweep_iname, scan_iname), (dim_type.set, dim_type.param)) + (sweep_iname, scan_iname), (dim_type.set, dim_type.param) + ).to_set() scan_iname_idx = domain_with_sweep_param.find_dim_by_name( dim_type.set, scan_iname) @@ -600,6 +601,7 @@ def _try_infer_scan_stride(kernel, scan_iname, sweep_iname, sweep_lower_bound): def _get_domain_with_iname_as_param(domain, iname): dim_type = isl.dim_type + domain = domain.to_set() if domain.find_dim_by_name(dim_type.param, iname) >= 0: return domain diff --git a/test/test_callables.py b/test/test_callables.py index 8f82621cf..fd04ca332 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -269,7 +269,7 @@ def test_multi_arg_array_call(ctx_factory: cl.CtxFactory): lp.Assignment(id="init1", assignee=acc_i, expression="214748367"), lp.Assignment(id="insn", assignee=index, - expression=p.If(p.Expression.eq(acc_i, a_i), i, index), + expression=p.If(p.ExpressionNode.eq(acc_i, a_i), i, index), depends_on="update"), lp.Assignment(id="update", assignee=acc_i, expression=p.Variable("min")(acc_i, a_i),