Restore backbone-only .compile() override on forcefield regressors#165
Closed
timduignan wants to merge 2 commits into
Conversation
Without an override, model.compile(...) falls through to
nn.Module.compile(), which compiles self.__call__ (= forward). For
DirectForcefieldRegressor, that has no effect at inference: predict()
calls self.model(batch) directly and never goes through __call__, so
the compiled callable is created but never invoked. For
ConservativeForcefieldRegressor, predict() does flow through __call__,
but compiling the full regressor pulls the energy autograd backward
and (for OrbMol-v2) the Coulomb / PME path into the traced graph,
which dynamo fragments badly.
Restore the override on both classes so that .compile() compiles
self.model (the GNS message-passing backbone) where almost all the
FLOPs live; post-backbone work runs eager. This matches what the
internal core repo's orbmolv2 branch already does, and was the
behaviour before the porting commit that exposed this regression.
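The dispatch problem described above can be reproduced without PyTorch. What follows is a minimal sketch, not the repository's code: `ToyModule` is a hypothetical stand-in that only mimics how `nn.Module.compile()` stores a callable in `_compiled_call_impl` and how `__call__` uses it. `Backbone`, `BrokenRegressor`, and `FixedRegressor` are likewise made-up names, and "compilation" is simulated by a wrapper that counts its invocations.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module; models only the compile dispatch."""

    def __init__(self):
        self._compiled_call_impl = None  # set by compile(), as in nn.Module
        self.compiled_calls = 0          # how often the "compiled" path ran

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, x):
        # The compiled callable is used only when the call goes through __call__.
        if self._compiled_call_impl is not None:
            return self._compiled_call_impl(x)
        return self.forward(x)

    def compile(self):
        # Simulated compilation: wrap forward and count invocations.
        def compiled(x):
            self.compiled_calls += 1
            return self.forward(x)

        self._compiled_call_impl = compiled


class Backbone(ToyModule):
    def forward(self, x):
        return 2 * x


class BrokenRegressor(ToyModule):
    """No override: compile() wraps the regressor's own __call__ path."""

    def __init__(self):
        super().__init__()
        self.model = Backbone()

    def forward(self, x):
        return self.model(x) + 1

    def predict(self, x):
        # Calls the backbone directly, bypassing self.__call__.
        return self.model(x) + 1


class FixedRegressor(BrokenRegressor):
    """Override: compile() targets only the backbone."""

    def compile(self):
        self.model.compile()


broken, fixed = BrokenRegressor(), FixedRegressor()
broken.compile()
fixed.compile()
broken.predict(3)
fixed.predict(3)
print(broken.compiled_calls, broken.model.compiled_calls)  # 0 0
print(fixed.compiled_calls, fixed.model.compiled_calls)    # 0 1
```

In the toy, `broken.predict()` never touches a compiled callable, while `fixed.predict()` hits the compiled backbone once, which is the behaviour the override is meant to restore.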
Verified on a single H100, orb_v3_direct_20_omat, 3 trials each
(variance ~1%):
variant                                   1k atoms   10k atoms
compile=True (pre-fix default)            9.23 ms    38.36 ms
compile=False (eager)                     9.23 ms    38.32 ms
manual model.model.compile() (this fix)   7.07 ms    30.04 ms
i.e. the pre-fix .compile() was a no-op in inference for direct
models, and the restored backbone compile recovers a 22-24% inference
speedup.
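The quoted speedup can be checked directly from the table. This is a small arithmetic sketch using only the latencies reported above; the dictionary keys and the `reduction_pct` helper are illustrative names, not part of the repository.

```python
# Mean latencies in milliseconds, copied from the benchmark table.
latency_ms = {
    "1k atoms": {"pre_fix": 9.23, "eager": 9.23, "backbone": 7.07},
    "10k atoms": {"pre_fix": 38.36, "eager": 38.32, "backbone": 30.04},
}


def reduction_pct(before, after):
    """Relative latency reduction, in percent."""
    return 100.0 * (before - after) / before


for size, t in latency_ms.items():
    cut = reduction_pct(t["pre_fix"], t["backbone"])
    ratio = t["pre_fix"] / t["backbone"]
    print(f"{size}: {cut:.1f}% lower latency, {ratio:.2f}x faster")
# 1k atoms: 23.4% lower latency, 1.31x faster
# 10k atoms: 21.7% lower latency, 1.28x faster
```

Both values sit at or near the 22-24% range quoted above once rounded, and correspond to roughly a 1.3x ratio.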
Adds tests/forcefield/test_{direct,conservative}.py::test_compile_engages_backbone
that asserts the backbone has _compiled_call_impl set after .compile(),
so this can't silently re-break.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The replacement bullet I added in the previous commit framed the restored backbone compile as a feature with a ~1.3x speedup, which is misleading. Older releases already had backbone compile working via the same override; v0.7.0 silently broke it (df8f4f0) and the previous commit just restores prior behaviour. There's no new speedup to announce, so the cleanest thing is to drop the bullet entirely.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
Aren't we overcorrecting here by removing the full compilation? Especially since it's passed our tests/evaluations for the conservative model. The direct regressor was indeed missing the compiled path via |
Contributor
Author
Brilliant thanks yep sorry yeah this is wrong deleting |
Summary
Without an override, model.compile(...) falls through to nn.Module.compile(), which compiles self.__call__ (= forward). That has two problems:
DirectForcefieldRegressor: predict() calls self.model(batch) directly and never goes through __call__, so the compiled callable is created but never invoked. .compile() is silently a no-op for inference.
ConservativeForcefieldRegressor: predict() does flow through __call__, but compiling the full regressor pulls the energy autograd backward and (for OrbMol-v2) the Coulomb / PME path into the traced graph, which dynamo fragments badly.
This PR restores the small .compile() override on both classes so self.model (the GNS message-passing backbone, where almost all the FLOPs live) is the thing actually compiled; post-backbone work runs eager. This matches what the internal core repo's orbmolv2 branch already does, and was the behaviour before the recent porting commit that exposed this regression.
Numbers
orb_v3_direct_20_omat on a single H100, 3 trials each (variance ~1%):
variant                               1k atoms   10k atoms
compile=True (pre-fix default)        9.23 ms    38.36 ms
compile=False (eager)                 9.23 ms    38.32 ms
model.model.compile(...) (this fix)   7.07 ms    30.04 ms
So the pre-fix .compile() was a literal no-op in inference for direct models, and restoring the backbone compile recovers a 22-24% inference speedup. Conservative models also benefit by getting a clean single-fragment compile graph instead of the fragmented full-regressor trace.
Tests
Adds test_compile_engages_backbone to both tests/forcefield/test_direct_regressor.py and tests/forcefield/test_conservative.py. These assert that model.model._compiled_call_impl is not None after .compile(), so this can't silently re-break if someone removes the override again in a future refactor.
Other changes
Test plan
pytest tests/forcefield/test_direct_regressor.py::test_compile_engages_backbone tests/forcefield/test_conservative.py::test_compile_engages_backbone passes
tests/forcefield/test_{direct_regressor,conservative}.py passes in CI
test_regressor_compile_matches_eager still passes (compiled output equivalent to eager)
🤖 Generated with Claude Code