From 8959d8e91a9b4b3ff66e160ed0cf32e0316735ea Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 28 Nov 2024 00:25:09 -0600 Subject: [PATCH] Direct conn: drop support for actxs without broadcasting --- meshmode/discretization/connection/direct.py | 125 +++++++------------ test/test_meshmode.py | 14 +-- 2 files changed, 45 insertions(+), 94 deletions(-) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index a0763ccd..88065669 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -568,7 +568,6 @@ def _global_point_pick_info( def __call__( self, ary: ArrayOrContainerT, *, - _force_use_loopy: bool = False, _force_no_merged_batches: bool = False, ) -> ArrayOrContainerT: """ @@ -577,8 +576,8 @@ def __call__( coefficient data on :attr:`from_discr`. """ - # _force_use_loopy, _force_no_merged_batches: - # private arguments only used to ensure test coverage of all code paths. + # _force_no_merged_batches: + # private argument only used to ensure test coverage of all code paths. # {{{ recurse into array containers @@ -593,7 +592,6 @@ def __call__( else: return deserialize_container(ary, [ (key, self(subary, - _force_use_loopy=_force_use_loopy, _force_no_merged_batches=_force_no_merged_batches)) for key, subary in iterable ]) @@ -706,6 +704,9 @@ def group_pick_knl(is_surjective: bool): "idof": ConcurrentDOFInameTag()}) # }}} + if not actx.permits_advanced_indexing: + raise ValueError("Array context does not allow advanced indexing. " + "This is no longer supported.") group_arrays = [] for i_tgrp, (cgrp, group_pick_info) in enumerate( @@ -719,51 +720,33 @@ def group_pick_knl(is_surjective: bool): if group_pick_info is not None: group_array_contributions = [] - if actx.permits_advanced_indexing and not _force_use_loopy: - for fgpd in group_pick_info: - from_element_indices = actx.thaw(fgpd.from_element_indices) - - if ary[fgpd.from_group_index].size: - grp_ary_contrib = ary[fgpd.from_group_index][ - _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), - actx.thaw(fgpd.dof_pick_lists)[ - actx.thaw(fgpd.dof_pick_list_indices)] - ] - - if not fgpd.is_surjective: - from_el_present = actx.thaw(fgpd.from_el_present) - grp_ary_contrib = actx.np.where( + for fgpd in group_pick_info: + from_element_indices = actx.thaw(fgpd.from_element_indices) + + if ary[fgpd.from_group_index].size: + grp_ary_contrib = ary[fgpd.from_group_index][ _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), - grp_ary_contrib, - 0) - - # attach metadata - grp_ary_contrib = tag_axes( - actx, - {0: DiscretizationElementAxisTag(), - 1: DiscretizationDOFAxisTag()}, - grp_ary_contrib) - - group_array_contributions.append(grp_ary_contrib) - else: - for fgpd in group_pick_info: - group_knl_kwargs = {} + actx, from_element_indices, (-1, 1)), + actx.thaw(fgpd.dof_pick_lists)[ + actx.thaw(fgpd.dof_pick_list_indices)] + ] + if not fgpd.is_surjective: - group_knl_kwargs["from_el_present"] = \ - fgpd.from_el_present - - group_array_contributions.append( - actx.call_loopy( - group_pick_knl(fgpd.is_surjective), - dof_pick_lists=fgpd.dof_pick_lists, - dof_pick_list_indices=fgpd.dof_pick_list_indices, - ary=ary[fgpd.from_group_index], - from_element_indices=fgpd.from_element_indices, - nunit_dofs_tgt=( - self.to_discr.groups[i_tgrp].nunit_dofs), - **group_knl_kwargs)["result"]) + from_el_present = actx.thaw(fgpd.from_el_present) + grp_ary_contrib = actx.np.where( + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1)), + grp_ary_contrib, + 0) + + # attach metadata + grp_ary_contrib = tag_axes( + actx, + {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + grp_ary_contrib) + + group_array_contributions.append(grp_ary_contrib) group_array = sum(group_array_contributions) elif cgrp.batches: @@ -783,47 +766,25 @@ def group_pick_knl(is_surjective: bool): if point_pick_indices is None: grp_ary = ary[batch.from_group_index] mat = self._resample_matrix(actx, i_tgrp, i_batch) - if actx.permits_advanced_indexing and not _force_use_loopy: - batch_result = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), - actx.einsum("ij,ej->ei", - mat, grp_ary[from_element_indices]), - 0) - else: - batch_result = actx.call_loopy( - batch_mat_knl(), - resample_mat=mat, - ary=grp_ary, - from_el_present=from_el_present, - from_element_indices=from_element_indices, - nunit_dofs_tgt=( - self.to_discr.groups[i_tgrp].nunit_dofs) - )["result"] + batch_result = actx.np.where( + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1)), + actx.einsum("ij,ej->ei", + mat, grp_ary[from_element_indices]), + 0) else: from_vec = ary[batch.from_group_index] pick_list = actx.thaw(point_pick_indices) - if actx.permits_advanced_indexing and not _force_use_loopy: - batch_result = actx.np.where( + batch_result = actx.np.where( + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1)), + from_vec[ _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), - from_vec[ - _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), - pick_list], - 0) - else: - batch_result = actx.call_loopy( - batch_pick_knl(), - pick_list=pick_list, - ary=from_vec, - from_el_present=from_el_present, - from_element_indices=from_element_indices, - nunit_dofs_tgt=( - self.to_discr.groups[i_tgrp].nunit_dofs) - )["result"] + actx, from_element_indices, (-1, 1)), + pick_list], + 0) # attach metadata batch_result = tag_axes(actx, diff --git a/test/test_meshmode.py b/test/test_meshmode.py index f264b646..9bcb02eb 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -455,13 +455,8 @@ def f(x): bdry_f_2 = opp_face(bdry_f) # Ensure test coverage for alternate modes in DirectConnection - for force_loopy, force_no_merged_batches in [ - (False, True), - (True, False), - (True, True), - ]: + for force_no_merged_batches in [False, True]: bdry_f_2_alt = opp_face(bdry_f, - _force_use_loopy=force_loopy, _force_no_merged_batches=force_no_merged_batches) assert actx.to_numpy(flat_norm(bdry_f_2 - bdry_f_2_alt, np.inf)) < 1e-14 @@ -994,13 +989,8 @@ def grp_factory(mesh_el_group: MeshElementGroup): op_bdry_f = opposite(bdry_f) # Ensure test coverage for alternate modes in DirectConnection - for force_loopy, force_no_merged_batches in [ - (False, True), - (True, False), - (True, True), - ]: + for force_no_merged_batches in [False, True]: op_bdry_f_2 = opposite(bdry_f, - _force_use_loopy=force_loopy, _force_no_merged_batches=force_no_merged_batches) error = flat_norm(op_bdry_f - op_bdry_f_2, np.inf) assert actx.to_numpy(error) < 1e-15