Introduce more extensible scheme for extra properties #464
Conversation
…h spins and charges
```python
spin: torch.Tensor | None = field(default=None)
system_idx: torch.Tensor | None = field(default=None)
_constraints: list["Constraint"] = field(default_factory=lambda: [])  # noqa: PIE807
_system_extras: dict[str, torch.Tensor] = field(default_factory=dict)
```
could consider removing privileged positions of spin and charge
It would honestly be a good test. If we can move spin and charge to extras without breaking anything that would help to validate this PR.
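One way to sketch that migration (purely illustrative; the real torch-sim state class has many more fields, and the property names here are assumptions) is to keep `spin` as a thin backward-compatible property that forwards to `_system_extras`:

```python
# Hypothetical sketch: spin stops being a privileged dataclass field and
# becomes a plain entry in _system_extras, with a property kept so existing
# call sites like `state.spin` keep working. Not the actual torch-sim API.
from dataclasses import dataclass, field


@dataclass
class State:
    _system_extras: dict = field(default_factory=dict)

    @property
    def spin(self):
        # Falls back to None when the extra was never set,
        # matching the old field(default=None) behaviour.
        return self._system_extras.get("spin")

    @spin.setter
    def spin(self, value):
        self._system_extras["spin"] = value


s = State()
s.spin = [0, 1]          # routed into _system_extras
print(s._system_extras)  # {'spin': [0, 1]}
```

If existing tests pass with this indirection in place, that would be the validation described above.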
```python
static_state.store_model_extras(model_outputs)

props = trajectory_reporter.report(static_state, 0, model=model)
```
I would probably be more in favor of getting rid of trajectory_reporter here and just concatenating the returned results. We should just copy the logic for detecting whether a property is per-atom or per-system.
TrajectoryReporter has more functionality than that, but I'm not sure many people use it when running this static() function.
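The detection logic mentioned above could be a simple leading-dimension heuristic (a sketch only; `classify` is a hypothetical helper, and the rule is ambiguous when `n_atoms == n_systems`):

```python
# Rough sketch of classifying a model output as per-atom or per-system by
# comparing its leading dimension to the atom and system counts.
import numpy as np


def classify(value, n_atoms, n_systems):
    """Guess the scope of a model output from its leading dimension."""
    if value.ndim >= 1 and value.shape[0] == n_atoms:
        return "per-atom"
    if value.ndim >= 1 and value.shape[0] == n_systems:
        return "per-system"
    return "global"


forces = np.zeros((5, 3))  # 5 atoms x 3 components
energy = np.zeros((2,))    # 2 systems
print(classify(forces, n_atoms=5, n_systems=2))  # per-atom
```

The ambiguous case (a single system with a single atom, say) is why an explicit scope tag, as in the `set_extras` signature later in this PR, may be safer than shape sniffing.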
That is currently the default behavior but it's a bit convoluted. If no reporter is given:
- `static` will make one when `_configure_reporter` is called.
- When calling report, the file IO will be bypassed in favor of a concatenation as you describe.
TrajectoryReporter probably needs a rework. Both to make it compatible with h5md (#19) and to make it easier to report out model properties, as you've identified.
For the latter, the simplest solution (perhaps not the best) would be to add the model_outputs as an optional kwarg to the report function, so that model predictions can be covered by the current API without needing an additional call.
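That suggestion might look roughly like this (a sketch only; the real `TrajectoryReporter.report` signature differs, and the `model_outputs` kwarg is hypothetical):

```python
# Illustrative stand-in for a report() that accepts model predictions
# directly, so they ride along without a second call. Not the actual
# TrajectoryReporter API.
def report(state, step, model=None, model_outputs=None):
    props = {"step": step}
    if model_outputs is not None:
        # Merge predicted properties into the reported dict.
        props.update(model_outputs)
    return props


out = report(state=None, step=0, model_outputs={"energy": -1.5})
```

This keeps the current call sites working (the kwarg defaults to None) while letting runners forward model outputs in one call.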
Nice work! I think that most of the time extra predicted properties are not useful in MD, so that …
I think here the aim was to just get something that would be general enough to be both input and output. Is this sufficient for the MatRIS use-case? Do you want to try exploring that on top of this, and then we can revisit this design?

@Asecretboy would you be interested in adding a torchsim interface to https://github.com/HPC-AI-Team/MatRIS as an alternative to https://github.com/HPC-AI-Team/MatRIS/blob/c16f569ca08e6905e91b64e2ee68614303e46f7f/matris/applications/base.py#L24? TorchSim is an engine for batched MLIP workflows that enables fast MD with replicas/repeats or batched geometry optimization with minimal conceptual overhead c.f. ASE.
Just putting this here for documentation purposes - but this is similar to another PR that I wrote and closed earlier: #354
I added a torch-sim interface HPC-AI-Team/MatRIS#5 |
Broadly this seems like quite a reasonable interface. A few comments.
One question: how would we handle fluctuating extras? E.g. a user wants to apply a fluctuating E field to a system. It occurs to me that we don't handle that, but perhaps it's an advanced use case beyond the support of the default runners.
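For the record, one workaround with the extras dict as designed would be a custom loop that refreshes the extra each step (a sketch only; `state` and `step_fn` are stand-ins for torch-sim objects, and whether the model reads the extra is exactly the open question above):

```python
# Hypothetical sketch: a time-dependent extra (e.g. a fluctuating E field)
# handled by mutating _system_extras inside a user-written loop, rather
# than by the default runners.
import math

state = {"_system_extras": {"e_field": 0.0}}


def step_fn(state, t):
    # The step (and, for a field, the model) reads the current extra.
    return state["_system_extras"]["e_field"]


fields = []
for t in range(4):
    state["_system_extras"]["e_field"] = math.sin(0.5 * t)  # refresh per step
    fields.append(step_fn(state, t))
```

This mirrors how temperature schedules are fed to runners, per the comment at the end of the thread; the missing piece is the model consuming the per-step value.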
```python
# Concatenate extras
concatenated["_system_extras"] = {
    key: torch.cat(tensors, dim=0) for key, tensors in system_extras_tensors.items()
}
```
This should be a stack operation, I believe? If we move charge & spin to extras it will give us an already extant test.
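A minimal illustration of the point: if each state contributes a 0-d scalar tensor for a per-system extra, `torch.cat` cannot join them, while `torch.stack` builds the expected batch dimension (assuming scalar per-system extras; for already-batched `(n_systems, ...)` tensors, `cat` along dim 0 would be correct):

```python
# cat vs stack for scalar per-system extras.
import torch

a, b = torch.tensor(1.0), torch.tensor(2.0)  # one scalar extra per state

stacked = torch.stack([a, b])  # shape (2,): one entry per system
try:
    torch.cat([a, b], dim=0)   # 0-d tensors cannot be concatenated
    cat_ok = True
except RuntimeError:
    cat_ok = False
```

So which operation is right depends on whether each incoming extra already carries a leading system dimension, which is worth pinning down in a test.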
```python
_system_extras: dict[str, torch.Tensor] = {}
if system_extras_keys:
    for key in system_extras_keys:
        vals = [at.info.get(key) for at in atoms_list]
        if all(v is not None for v in vals):
            _system_extras[key] = torch.tensor(
                np.stack(vals), dtype=dtype, device=device
            )
```
```python
_atom_extras: dict[str, torch.Tensor] = {}
if atom_extras_keys:
    for key in atom_extras_keys:
        arrays = [at.arrays.get(key) for at in atoms_list]
        if all(a is not None for a in arrays):
            _atom_extras[key] = torch.tensor(
                np.concatenate(arrays), dtype=dtype, device=device
            )
```
This makes me a bit nervous; there is no guarantee that ASE model interfaces will need the same shape/structure as TorchSim model interfaces. It feels like it opens up an easy place for bugs to slide into the interface. This could be a reason to leave charge and spin in their privileged positions.
(confidence 50% on this opinion)
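One way to blunt that risk without re-privileging the fields would be an eager shape check when extras are attached (an illustrative guard, not part of the PR; `validate_atom_extra` is a hypothetical helper):

```python
# Validate per-atom extras against the state at construction time, so a
# shape mismatch fails loudly here rather than deep inside a model call.
import numpy as np


def validate_atom_extra(key, value, n_atoms):
    if value.shape[0] != n_atoms:
        raise ValueError(
            f"per-atom extra {key!r} has leading dim {value.shape[0]}, "
            f"expected n_atoms={n_atoms}"
        )


validate_atom_extra("magmoms", np.zeros(8), n_atoms=8)  # passes silently
```

An analogous check against `n_systems` for `_system_extras` would cover the other scope.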
```python
def set_extras(
    self,
    key: str,
    value: torch.Tensor,
    scope: Literal["per-system", "per-atom"],
) -> None:
```
this doesn't appear to be used anywhere?
```python
def has_extras(self, key: str) -> bool:
    """Check if an extras key exists."""
    return key in self._system_extras or key in self._atom_extras
```
also doesn't appear to be used
This could be like temperature, which the runners, and thereby the step functions, can take as inputs at each step. The only difference is that in this case the model also has to take them as inputs, not just the step functions. I think the latter is not dealt with right now? This assumes something like the E field doesn't change the step function itself.
This is a rough draft to try to unlock #463.