From 28b8d938ef209e15d082b5793cb63f0a76feee7b Mon Sep 17 00:00:00 2001 From: Jas Kalayan Date: Thu, 12 Feb 2026 09:01:25 +0000 Subject: [PATCH 1/2] use vanilla PA for molecules that are one UA --- CodeEntropy/axes.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/CodeEntropy/axes.py b/CodeEntropy/axes.py index d698b48..d04b2fc 100644 --- a/CodeEntropy/axes.py +++ b/CodeEntropy/axes.py @@ -92,21 +92,25 @@ def get_UA_axes(self, data_container, index): moment_of_inertia: moment of inertia (3,) """ - index = int(index) + index = int(index) # bead index # use the same customPI trans axes as the residue level - UAs = data_container.select_atoms("mass 2 to 999") - UA_masses = self.get_UA_masses(data_container.atoms) - center = data_container.atoms.center_of_mass(unwrap=True) - moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( - center, UAs.positions, UA_masses, data_container.dimensions[:3] - ) - trans_axes, _moment_of_inertia = self.get_custom_principal_axes( - moment_of_inertia_tensor - ) + heavy_atoms = data_container.select_atoms("prop mass > 1.1") + if len(heavy_atoms) > 1: + UA_masses = self.get_UA_masses(data_container.atoms) + center = data_container.atoms.center_of_mass(unwrap=True) + moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( + center, heavy_atoms.positions, UA_masses, data_container.dimensions[:3] + ) + trans_axes, _moment_of_inertia = self.get_custom_principal_axes( + moment_of_inertia_tensor + ) + else: + # use standard PA for UA not bonded to anything else + make_whole(data_container.atoms) + trans_axes = data_container.atoms.principal_axes() # look for heavy atoms in residue of interest - heavy_atoms = data_container.select_atoms("prop mass > 1.1") heavy_atom_indices = [] for atom in heavy_atoms: heavy_atom_indices.append(atom.index) From 513664f0d78b1a962d2a26fa7d06c954b9e1950c Mon Sep 17 00:00:00 2001 From: Jas Kalayan Date: Thu, 12 Feb 2026 10:25:24 +0000 Subject: [PATCH 2/2] update tests for UA axes --- tests/test_CodeEntropy/test_axes.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/test_CodeEntropy/test_axes.py b/tests/test_CodeEntropy/test_axes.py index c5e2d38..b5f41aa 100644 --- a/tests/test_CodeEntropy/test_axes.py +++ b/tests/test_CodeEntropy/test_axes.py @@ -118,7 +118,8 @@ def test_get_residue_axes_bonded_default_axes_branch(self): np.testing.assert_allclose(center_out, center_expected) np.testing.assert_allclose(moi_out, np.array([3.0, 2.0, 1.0])) - def test_get_UA_axes_returns_expected_outputs(self): + @patch("CodeEntropy.axes.make_whole", autospec=True) + def test_get_UA_axes_returns_expected_outputs(self, mock_make_whole): """ Tests that: `get_UA_axes` returns expected UA axes. """ @@ -129,20 +130,17 @@ def test_get_UA_axes_returns_expected_outputs(self): dc.dimensions = np.array([1.0, 2.0, 3.0, 90.0, 90.0, 90.0]) dc.atoms.center_of_mass.return_value = np.array([0.0, 0.0, 0.0]) - uas = MagicMock() - uas.positions = np.zeros((2, 3)) - a0 = MagicMock() a0.index = 7 a1 = MagicMock() a1.index = 9 - heavy_atoms = [a0, a1] - heavy_ag = MagicMock() - heavy_ag.positions = np.array([[9.9, 8.8, 7.7]]) - heavy_ag.__getitem__.return_value = MagicMock() + heavy_atoms = MagicMock() + heavy_atoms.__len__.return_value = 2 + heavy_atoms.__iter__.return_value = iter([a0, a1]) + heavy_atoms.positions = np.array([[9.9, 8.8, 7.7], [1.1, 2.2, 3.3]]) - dc.select_atoms.side_effect = [uas, heavy_atoms, heavy_ag] + dc.select_atoms.side_effect = [heavy_atoms, heavy_atoms] axes.get_UA_masses = MagicMock(return_value=[1.0, 1.0]) axes.get_moment_of_inertia_tensor = MagicMock(return_value=np.eye(3)) @@ -160,13 +158,12 @@ def test_get_UA_axes_returns_expected_outputs(self): np.testing.assert_array_equal(trans_axes, trans_axes_expected) np.testing.assert_array_equal(rot_axes, rot_axes_expected) - np.testing.assert_array_equal(center, heavy_ag.positions[0]) + np.testing.assert_array_equal(center, heavy_atoms.positions[0]) np.testing.assert_array_equal(moi, moi_expected) calls = [c.args[0] for c in dc.select_atoms.call_args_list] - assert calls[0] == "mass 2 to 999" - assert calls[1] == "prop mass > 1.1" - assert calls[2] == "index 9" + assert calls[0] == "prop mass > 1.1" + assert calls[1] == "index 9" def test_get_bonded_axes_returns_none_for_light_atom(self): """