Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions core/optimizer/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,54 @@ def compute_mad(results: List[FitAttempt]) -> Optional[np.ndarray]:
return mad


def compute_mean_params(parameter_samples: dict[str, np.ndarray]) -> dict[str, float]:
"""Mean of each parameter across the accepted-fit pool.

Classical (non-robust) counterpart to :func:`compute_median_params`. It
operates on the pooled ``FitResult.parameter_samples`` mapping (one flat
array of accepted-fit values per parameter key) rather than raw
``FitAttempt`` objects, because that pool is what downstream consumers
(the fitted-parameters table, distribution plots) actually carry.

Parameters
----------
parameter_samples : dict[str, np.ndarray]
One flat array of accepted-fit values per parameter key.

Returns
-------
dict[str, float]
Parameter key -> arithmetic mean of its accepted-fit values.
"""
return {key: float(np.mean(values)) for key, values in parameter_samples.items()}


def compute_std_params(parameter_samples: dict[str, np.ndarray]) -> dict[str, float]:
"""Standard deviation of each parameter across the accepted-fit pool.

Classical (non-robust) counterpart to :func:`compute_mad`. Uses the sample
standard deviation (``ddof=1``) — the conventional unbiased estimator that
pairs with the arithmetic mean. A single accepted fit has no estimable
spread and returns ``0.0`` for that key, mirroring :func:`compute_mad`
returning 0 for a single sample.

Parameters
----------
parameter_samples : dict[str, np.ndarray]
One flat array of accepted-fit values per parameter key.

Returns
-------
dict[str, float]
Parameter key -> sample standard deviation of its accepted-fit values.
"""
std_params: dict[str, float] = {}
for key, values in parameter_samples.items():
arr = np.asarray(values, dtype=float)
std_params[key] = float(np.std(arr, ddof=1)) if arr.size > 1 else 0.0
return std_params


def aggregate_fits(
results: List[FitAttempt],
rmse_threshold_factor: float = 1.5,
Expand Down
182 changes: 107 additions & 75 deletions gui/plotting/fit_summary_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,50 +14,61 @@
)

from core.assays.registry import ASSAY_REGISTRY, AssayType
from core.optimizer.filters import compute_mean_params, compute_std_params
from core.pipeline.fit_pipeline import FitResult
from core.units import Q_, Quantity
from gui.plotting.labels import fmt_param, fmt_unit_html
from gui.widgets.info_button import InfoGroupBox

_UNCERTAINTY_HELP_HTML = """
<h3>Uncertainty &mdash; what the &plusmn; value means</h3>

<p>The &plusmn; value next to each fitted parameter tells you <b>how
much that parameter varies</b> across all the acceptable fits the
fitter found. A small &plusmn; means the fitter consistently lands on
the same value; a large one means there is real spread.</p>
_HEADERS = ['Parameter', 'Median \u00b1 MAD', 'Mean \u00b1 STDEV', 'Units']

<p>Technically, the reported value is the <i>median</i> of the
acceptable fits, and the &plusmn; is their <i>median absolute
deviation</i> &mdash; a robust measure of spread that is not thrown off
by a few outliers.</p>
_UNCERTAINTY_HELP_HTML = """
<h3>The two &plusmn; summaries &mdash; what they mean</h3>

<p>Each parameter is summarised across <b>all accepted fits</b> the fitter
found (no re-fitting). Two pairs are shown so you can judge both the typical
value and how much it varies:</p>

<ul>
<li><b>Median &plusmn; MAD</b> &mdash; the <i>robust</i> summary. The median is
the middle value; the MAD (median absolute deviation) is the spread around it.
Both shrug off a few stray fits, so this pair stays stable even when the pool
has outliers or is skewed. The <i>Median Fit</i> curve drawn on the plot uses
these median values.</li>
<li><b>Mean &plusmn; STDEV</b> &mdash; the <i>classical</i> summary. The mean is
the arithmetic average; the STDEV (standard deviation) is the textbook spread.
Both weight every fit equally, so they are the familiar numbers &mdash; but a
single outlying fit pulls them noticeably.</li>
</ul>

<h4>MAD is not the same as STDEV</h4>
<p>They measure spread in different ways. For a clean, bell-shaped (Gaussian)
pool they agree after a fixed scaling: <b>STDEV &asymp; 1.48 &times; MAD</b>.
When the two pairs roughly satisfy that, the accepted fits are well behaved and
either summary is fine.</p>
<p>When they <i>disagree</i> &mdash; typically STDEV much larger than
1.48&nbsp;&times;&nbsp;MAD, or the mean sitting far from the median &mdash; a
few outlying or skewed fits are inflating the classical numbers. Trust the
robust <b>Median &plusmn; MAD</b> in that case, and inspect the spread in the
distribution (box-and-whisker) plot.</p>

<h4>Average mode</h4>
<p>Your replicas are averaged into one curve, which is then fit many
times from different starting points. The &plusmn; reflects how
precisely the fitter can pin down the parameter on that averaged curve.
This is a measure of <b>numerical precision</b> &mdash; it does not
capture replica-to-replica variation.</p>
<p>Your replicas are averaged into one curve, which is then fit many times from
different starting points. The spread reflects how precisely the fitter can pin
down the parameter on that averaged curve &mdash; a measure of <b>numerical
precision</b>. It does not capture replica-to-replica variation.</p>

<h4>Per-replica mode</h4>
<p>Each replica is fit independently, and every acceptable fit from
every replica is collected together. The &plusmn; now reflects the
full spread &mdash; including differences between replicas. This is a
measure of <b>experimental reproducibility</b>, which is typically the
number you would report in a publication.</p>

<h4>The Median Fit curve</h4>
<p>The curve drawn on the plot uses the median parameter values from all
acceptable fits. It is labelled <i>Median Fit</i> because it
represents the middle of the distribution, not the single &ldquo;best&rdquo;
attempt.</p>
<p>Each replica is fit independently and every acceptable fit from every
replica is pooled together. The spread now includes differences between
replicas &mdash; a measure of <b>experimental reproducibility</b>, typically
the number you would report in a publication.</p>

<h4>Which mode am I using?</h4>
<p>The column header tells you: <i>&plusmn;&nbsp;Uncertainty
(optimiser)</i> in average mode, or <i>&plusmn;&nbsp;Uncertainty
(pool&nbsp;N=&hellip;, &hellip;&nbsp;replicas)</i> in per-replica
mode. Switch between modes in Fit Configuration &rarr; &ldquo;Fit per
replica&rdquo;.</p>
<p>The caption under the table says so &mdash; <i>average mode</i> or
<i>per-replica mode</i> &mdash; along with how many accepted fits (N) the
statistics were computed from. Switch modes in Fit Configuration &rarr;
&ldquo;Fit per replica&rdquo;.</p>
"""


Expand All @@ -82,13 +93,34 @@
"""


def _magnitude(value) -> float:
"""Return the float magnitude of a Quantity or plain number."""
return float(value.magnitude) if isinstance(value, Quantity) else float(value)


def _fmt_mag(magnitude: float, unit_str: str) -> str:
"""Format a magnitude for display, stripping any unit (shown separately)."""
if unit_str:
return f'{Q_(magnitude, unit_str):.3g~H}'.rsplit(' ', 1)[0]
return f'{magnitude:.3g}'
Comment on lines +101 to +105


def _make_cell(html: str) -> QLabel:
"""Build a centred rich-text cell label for the parameters table."""
lbl = QLabel(html)
lbl.setAlignment(Qt.AlignmentFlag.AlignCenter)
lbl.setTextFormat(Qt.TextFormat.RichText)
return lbl


class FitSummaryWidget(QWidget):
"""Read-only display of ``FitResult`` statistics.

Layout
------
- ``QGroupBox("Fitted Parameters")`` with columns:
Parameter | Value | +/- Uncertainty | Units
Parameter | Median +/- MAD | Mean +/- STDEV | Units, plus a caption
noting the fit mode and accepted-fit count.
- ``QGroupBox("Fit Quality")`` with RMSE, R-squared, Fits passing.
"""

Expand All @@ -97,18 +129,21 @@ def __init__(self, parent=None):

self._params_group = InfoGroupBox(
'Fitted Parameters',
'Uncertainty: what the \u00b1 value means',
'Median \u00b1 MAD vs Mean \u00b1 STDEV',
_UNCERTAINTY_HELP_HTML,
)
self._table = QTableWidget(0, 4)
self._uncertainty_header_default = '\u00b1 Uncertainty (optimiser)'
self._table.setHorizontalHeaderLabels(['Parameter', 'Value', self._uncertainty_header_default, 'Units'])
self._table = QTableWidget(0, len(_HEADERS))
self._table.setHorizontalHeaderLabels(_HEADERS)
header = self._table.horizontalHeader()
header.setStretchLastSection(False)
header.setSectionResizeMode(QHeaderView.ResizeMode.Interactive)
self._table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
self._caption = QLabel()
self._caption.setWordWrap(True)
self._caption.setStyleSheet('color: gray;')
params_layout = QVBoxLayout(self._params_group)
params_layout.addWidget(self._table)
params_layout.addWidget(self._caption)

self._quality_group = InfoGroupBox('Fit Quality', 'Fit Quality', _FIT_QUALITY_HELP_HTML)
quality_layout = QFormLayout(self._quality_group)
Expand All @@ -127,6 +162,9 @@ def update_result(self, result: FitResult) -> None:
"""Populate the widget from a ``FitResult``.

Resolves units from ``ASSAY_REGISTRY`` using ``result.assay_type``.
Shows two stat pairs per parameter: the robust median/MAD (reusing the
stored values) and the classical mean/STDEV (computed from
``result.parameter_samples``; no re-fit).

Parameters
----------
Expand All @@ -137,54 +175,47 @@ def update_result(self, result: FitResult) -> None:
if assay_type is not None:
units = ASSAY_REGISTRY[assay_type].units

# The fit mode and pool size used to live in the uncertainty-column
# header; with two stat pairs that context moves to a caption.
if result.uncertainty_source == 'replicate': # JSON-compat magic value
pool_size = result.metadata.get('pool_size', result.n_passing)
n_reps = result.metadata.get('n_replicas_fit', '?')
header = f'\u00b1 Uncertainty (pool N={pool_size}, {n_reps} replicas)'
self._caption.setText(
f'Statistics across N = {pool_size} accepted fits pooled from {n_reps} replicas (per-replica mode).'
)
else:
header = self._uncertainty_header_default
self._table.setHorizontalHeaderLabels(['Parameter', 'Value', header, 'Units'])
self._caption.setText(f'Statistics across N = {result.n_passing} accepted fits (average mode).')

params = result.parameters
uncertainties = result.uncertainties
samples = result.parameter_samples
means = compute_mean_params(samples) if samples else None
stds = compute_std_params(samples) if samples else None

self._table.setRowCount(len(params))
for row, (key, value) in enumerate(params.items()):
unc = uncertainties.get(key, float('nan'))
for row, key in enumerate(params):
unit_str = units.get(key, '')

lbl_name = QLabel(fmt_param(key))
lbl_name.setAlignment(Qt.AlignmentFlag.AlignCenter)
self._table.setCellWidget(row, 0, lbl_name)

val_mag = float(value.magnitude) if isinstance(value, Quantity) else float(value)
unc_mag = float(unc.magnitude) if isinstance(unc, Quantity) else float(unc)

# Use Pint HTML formatter for proper superscript notation
if unit_str:
val_html = f'{Q_(val_mag, unit_str):.3g~H}'
unc_html = f'{Q_(unc_mag, unit_str):.3g~H}'
median_mag = _magnitude(params[key])
mad_mag = _magnitude(uncertainties.get(key, float('nan')))

self._table.setCellWidget(row, 0, _make_cell(fmt_param(key)))
self._table.setCellWidget(
row,
1,
_make_cell(f'{_fmt_mag(median_mag, unit_str)} \u00b1 {_fmt_mag(mad_mag, unit_str)}'),
)

if means is not None and key in means:
mean_cell = _make_cell(f'{_fmt_mag(means[key], unit_str)} \u00b1 {_fmt_mag(stds[key], unit_str)}')
else:
val_html = f'{val_mag:.3g}'
unc_html = f'{unc_mag:.3g}'
# Strip unit from the HTML — units shown in separate column
val_display = val_html.rsplit(' ', 1)[0] if unit_str else val_html
unc_display = unc_html.rsplit(' ', 1)[0] if unit_str else unc_html

lbl_val = QLabel(val_display)
lbl_val.setAlignment(Qt.AlignmentFlag.AlignCenter)
lbl_val.setTextFormat(Qt.TextFormat.RichText)
self._table.setCellWidget(row, 1, lbl_val)

lbl_unc = QLabel(unc_display)
lbl_unc.setAlignment(Qt.AlignmentFlag.AlignCenter)
lbl_unc.setTextFormat(Qt.TextFormat.RichText)
self._table.setCellWidget(row, 2, lbl_unc)

unit_html = fmt_unit_html(unit_str)
lbl_unit = QLabel(unit_html)
lbl_unit.setAlignment(Qt.AlignmentFlag.AlignCenter)
self._table.setCellWidget(row, 3, lbl_unit)
mean_cell = _make_cell('\u2014')
mean_cell.setToolTip(
'Mean \u00b1 STDEV needs the accepted-fit pool, which is '
'unavailable for this result (e.g. imported from an older file).'
)
self._table.setCellWidget(row, 2, mean_cell)

self._table.setCellWidget(row, 3, _make_cell(fmt_unit_html(unit_str)))

rmse_html = f'{Q_(result.rmse, "au"):.3g~H}'
self._rmse_label.setTextFormat(Qt.TextFormat.RichText)
Expand Down Expand Up @@ -218,7 +249,8 @@ def _autosize_columns(self) -> None:
def clear(self) -> None:
"""Reset all fields to their empty state."""
self._table.setRowCount(0)
self._table.setHorizontalHeaderLabels(['Parameter', 'Value', self._uncertainty_header_default, 'Units'])
self._table.setHorizontalHeaderLabels(_HEADERS)
self._caption.clear()
self._rmse_label.setText('\u2014')
self._r2_label.setText('\u2014')
self._passing_label.setText('\u2014')
Expand Down
Loading