-
Notifications
You must be signed in to change notification settings - Fork 248
api: fix interp/eval of expressions #2843
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,13 +95,36 @@ def dtype(self): | |
|
|
||
| @cached_property | ||
| def indices(self): | ||
| return tuple(filter_ordered(flatten(getattr(i, 'indices', ()) | ||
| for i in self._args_diff))) | ||
| if not self._args_diff: | ||
| return DimensionTuple() | ||
|
|
||
| # Get indices of all args and merge them | ||
| mapper = {} | ||
| for a in self._args_diff: | ||
| for d, i in a.indices.getters.items(): | ||
FabioLuporini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| mapper.setdefault(d, []).append(i) | ||
|
|
||
| # Filter unique indices | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpicking: blank line |
||
| mapper = {k: v[0] if len(v) == 1 else tuple(filter_ordered(v)) | ||
| for k, v in mapper.items()} | ||
|
|
||
| return DimensionTuple(*mapper.values(), getters=tuple(mapper.keys())) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpicking: blank line
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpicking: can just be |
||
|
|
||
| @cached_property | ||
| def dimensions(self): | ||
| return tuple(filter_ordered(flatten(getattr(i, 'dimensions', ()) | ||
| for i in self._args_diff))) | ||
| if not self._args_diff: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need these two lines?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because a generic Differentiable does not necessarily has an
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but the property is still there, it's not that you are exposing yourself to an AttributeError... so
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An empty tuple doesn't have a |
||
| return DimensionTuple() | ||
|
|
||
| # Use the staggering of the highest priority function | ||
| return highest_priority(self).dimensions | ||
|
|
||
| @cached_property | ||
| def staggered(self): | ||
| if not self._args_diff: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same story as before, I think ... unless I'm missing something
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
| return None | ||
|
|
||
| # Use the staggering of the highest priority function | ||
| return highest_priority(self).staggered | ||
|
|
||
| @cached_property | ||
| def root_dimensions(self): | ||
|
|
@@ -117,11 +140,6 @@ def indices_ref(self): | |
| return DimensionTuple(*self.dimensions, getters=self.dimensions) | ||
| return highest_priority(self).indices_ref | ||
|
|
||
| @cached_property | ||
| def staggered(self): | ||
| return tuple(filter_ordered(flatten(getattr(i, 'staggered', ()) | ||
| for i in self._args_diff))) | ||
|
|
||
| @cached_property | ||
| def is_Staggered(self): | ||
| return any([getattr(i, 'is_Staggered', False) for i in self._args_diff]) | ||
|
|
@@ -474,13 +492,21 @@ def has_free(self, *patterns): | |
| return all(i in self.free_symbols for i in patterns) | ||
|
|
||
|
|
||
| def highest_priority(DiffOp): | ||
| def highest_priority(diff_op): | ||
| if not diff_op._args_diff: | ||
| return diff_op | ||
|
|
||
| # We want to get the object with highest priority | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line |
||
| # We also need to make sure that the object with the largest | ||
| # set of dimensions is used when multiple ones with the same | ||
| # priority appear | ||
| prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions)) | ||
| return sorted(DiffOp._args_diff, key=prio, reverse=True)[0] | ||
| prio_func = sorted(diff_op._args_diff, key=prio, reverse=True)[0] | ||
|
|
||
| # The highest priority must be a Function | ||
| if not isinstance(prio_func, AbstractFunction): | ||
| return highest_priority(prio_func) | ||
| return prio_func | ||
|
|
||
|
|
||
| class DifferentiableOp(Differentiable): | ||
|
|
@@ -548,8 +574,11 @@ class DifferentiableFunction(DifferentiableOp): | |
| def __new__(cls, *args, **kwargs): | ||
| return cls.__sympy_class__.__new__(cls, *args, **kwargs) | ||
|
|
||
| def _eval_at(self, func): | ||
| return self | ||
| @property | ||
| def _fd_priority(self): | ||
| if highest_priority(self) is self: | ||
| return super()._fd_priority | ||
| return highest_priority(self)._fd_priority | ||
|
|
||
|
|
||
| class Add(DifferentiableOp, sympy.Add): | ||
|
|
@@ -633,26 +662,12 @@ def _gather_for_diff(self): | |
| if len(set(f.staggered for f in self._args_diff)) == 1: | ||
| return self | ||
|
|
||
| func_args = highest_priority(self) | ||
| new_args = [] | ||
| ref_inds = func_args.indices_ref.getters | ||
|
|
||
| for f in self.args: | ||
| if f not in self._args_diff \ | ||
| or f is func_args \ | ||
| or isinstance(f, DifferentiableFunction): | ||
| new_args.append(f) | ||
| else: | ||
| ind_f = f.indices_ref.getters | ||
| mapper = {ind_f.get(d, d): ref_inds.get(d, d) | ||
| for d in self.dimensions | ||
| if ind_f.get(d, d) is not ref_inds.get(d, d)} | ||
| if mapper: | ||
| new_args.append(f.subs(mapper)) | ||
| else: | ||
| new_args.append(f) | ||
|
|
||
| return self.func(*new_args, evaluate=False) | ||
| derivs, other = split(self.args, lambda a: isinstance(a, sympy.Derivative)) | ||
| if len(derivs) == 0: | ||
| return self._eval_at(highest_priority(self)) | ||
| else: | ||
| other = self.func(*other)._eval_at(highest_priority(self)) | ||
| return self.func(other, *derivs) | ||
|
|
||
|
|
||
| class Pow(DifferentiableOp, sympy.Pow): | ||
|
|
@@ -1034,6 +1049,9 @@ def __new__(cls, *args, base=None, **kwargs): | |
| obj = super().__new__(cls, *args, **kwargs) | ||
|
|
||
| try: | ||
| if base is obj: | ||
| # In some rare cases (rebuild?) base may be obj itself | ||
| base = base.base | ||
| obj.base = base | ||
| except AttributeError: | ||
| # This might happen if e.g. one attempts a (re)construction with | ||
|
|
@@ -1061,6 +1079,10 @@ def _eval_at(self, func): | |
| # and should not be re-evaluated at a different location | ||
| return self | ||
|
|
||
| @property | ||
| def indices_ref(self): | ||
| return self.base.indices_ref | ||
|
|
||
|
|
||
| class diffify: | ||
|
|
||
|
|
@@ -1184,6 +1206,29 @@ def _(expr, x0, **kwargs): | |
| return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs)) | ||
|
|
||
|
|
||
| @interp_for_fd.register(Mul) | ||
| def _(expr, x0, **kwargs): | ||
| # For a Mul expression, we interpolate the whole expression | ||
| # Do we actually need interpolation | ||
| if all(expr.indices[d] is i for d, i in x0.items()): | ||
| return expr | ||
|
|
||
| # Split args between those that need interp and those that don't | ||
| def test0(a): | ||
| return all(a.indices[d] is i for d, i in x0.items() if d in a.dimensions) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: blank line |
||
|
|
||
| oa, ia = split(expr._args_diff, | ||
| lambda a: isinstance(a, sympy.Derivative) or test0(a)) | ||
| oa = oa + tuple(a for a in expr.args if a not in expr._args_diff) | ||
|
|
||
| # Interpolate the necessary args | ||
| d_dims = tuple((d, 0) for d in x0) | ||
| fd_order = tuple(expr.interp_order for d in x0) | ||
| iexpr = expr.func(*ia).diff(*d_dims, fd_order=fd_order, x0=x0, **kwargs) | ||
|
|
||
| return expr.func(iexpr, *oa) | ||
|
|
||
|
|
||
| @interp_for_fd.register(sympy.Expr) | ||
| def _(expr, x0, **kwargs): | ||
| if expr.args: | ||
|
|
@@ -1194,7 +1239,8 @@ def _(expr, x0, **kwargs): | |
|
|
||
| @interp_for_fd.register(AbstractFunction) | ||
| def _(expr, x0, **kwargs): | ||
| x0_expr = {d: v for d, v in x0.items() if v.has(d)} | ||
| x0_expr = {d: v for d, v in x0.items() if v.has(d) | ||
| and expr.indices[d] is not v} | ||
| if x0_expr: | ||
| return expr.subs({expr.indices[d]: v for d, v in x0_expr.items()}) | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,10 +50,9 @@ def check_input(func): | |
| def wrapper(expr, *args, **kwargs): | ||
| try: | ||
| return S.Zero if expr.is_Number else func(expr, *args, **kwargs) | ||
| except AttributeError: | ||
| raise ValueError( | ||
| f"'{expr}' must be of type Differentiable, not {type(expr)}" | ||
| ) from None | ||
| except Exception as e: | ||
| raise type(e)(f"Error while computing finite-difference for expr={expr}: " | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reporting the type of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what you mean the message is |
||
| f"{e}") from e | ||
| return wrapper | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpicking: blank line