From d45f4b23837d2733fc11d4f551190bdf11464cf1 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:37:38 -0400 Subject: [PATCH 01/21] Document the DiLoCo implementation contract. Adds the faithful DiLoCo design contract for Plato, including the server-side outer optimizer sign convention, exact local-step H semantics, small-H sampling requirements, client-local optimizer and scheduler state ownership, and the implementation dependency graph.\n\nCovers Linear issue DT-408. --- docs/docs/development/diloco.md | 220 ++++++++++++++++++++++++++++++++ docs/mkdocs.yml | 4 +- 2 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 docs/docs/development/diloco.md diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md new file mode 100644 index 000000000..94dfd49fc --- /dev/null +++ b/docs/docs/development/diloco.md @@ -0,0 +1,220 @@ +# DiLoCo Design Contract + +This note defines what Plato will call faithful DiLoCo for the initial +implementation. It is a contract for the implementation issues that follow; it +does not describe runtime behavior that already exists in Plato. + +Faithful DiLoCo in Plato means algorithm-faithful execution of the DiLoCo +training loop inside Plato's federated runtime. It does not mean reproducing +the paper's exact C4 dataset, model scale, tokenizer, hardware topology, +pretraining duration, or final benchmark numbers. + +## Algorithm Contract + +DiLoCo has two optimizer levels: + +- The client-local inner optimizer trains each selected logical client for + exactly `H` local optimizer steps between synchronizations. +- The server-side outer optimizer updates the global model from the averaged + outer gradient. + +Plato's FedAvg-style model delta is: + +```text +plato_delta = client_after - global_before +``` + +DiLoCo's outer gradient is: + +```text +outer_gradient = global_before - client_after = -plato_delta +``` + +The DiLoCo server must still return a Plato-compatible model delta because +`algorithm.update_weights()` adds the returned delta to the current global +model. For example, outer SGD with learning rate `1.0` returns the averaged +Plato delta and is equivalent to FedAvg only when the same averaging rule is +used. + +The outer optimizer runs on the server. Clients run only the inner optimizer +and send model weights or weight-equivalent updates. Client-local optimizer and +scheduler state persists per logical client and is never sent to the server. + +## Local Work `H` + +`H` means client-local optimizer steps between synchronizations. It is not: + +- epochs, +- raw dataloader batches, or +- gradient-accumulation micro-batches. + +When gradient accumulation is enabled, `H` counts completed optimizer steps. +Raw batches that do not trigger `optimizer.step()` do not increment `H`. + +`H` may be smaller than one epoch. Faithful DiLoCo must therefore stop local +training mid-epoch after exactly `H` optimizer steps. This early stop must +still run normal trainer cleanup, state persistence, callback completion, and +reporting paths. It must not perform an extra final optimizer step. + +Small-`H` training must not repeatedly replay the same first `H` batches only +because the train loader is recreated each round. The implementation must use +round-aware resampling or an equivalent persistent sampling stream so each +logical client's local data stream advances across rounds in a reproducible +way. + +## State Ownership + +Server-owned state: + +- the global model, +- outer optimizer momentum or other outer optimizer state, +- aggregation metadata needed to update the global model. + +Client-owned state: + +- inner optimizer state, such as AdamW first and second moments, +- scheduler state and global/local optimizer-step counters, +- sampler or dataloader stream position needed for small-`H` continuity. + +Client-owned optimizer and scheduler state must not appear in client-server +payloads. It must remain local to the logical client, including when training +uses subprocesses. + +## Parameter And Buffer Policy + +By default, the outer optimizer applies only to trainable floating parameters. +This matches the algorithm definition, which optimizes model parameters. + +Floating buffers, such as batch normalization running statistics, are +synchronized without outer momentum by default. They use the selected averaging +rule but do not receive server-side momentum or Nesterov treatment. + +Non-floating buffers use conservative FedAvg-style behavior, including casting +or rounding as needed to preserve the buffer's dtype-compatible semantics. + +The implementation may offer `apply_outer_optimizer_to = "all_floating"` for +experiments, but the default must remain `parameters`. + +## Configuration Contract + +The faithful initial mode uses these configuration names and defaults: + +```toml +[server] +type = "diloco" + +[algorithm] +type = "fedavg" + +[trainer] +local_steps_per_round = H +preserve_optimizer_state = true +optimizer = "AdamW" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" # or "num_samples" +apply_outer_optimizer_to = "parameters" # or "all_floating" +``` + +`algorithm.type = "fedavg"` is intentional. Plato should reuse the existing +FedAvg weight extraction, delta computation, and global model loading path, +while `server.type = "diloco"` selects the server-side DiLoCo aggregation and +outer optimizer behavior. + +`aggregation_weighting = "uniform"` matches the balanced worker setting most +closely. `aggregation_weighting = "num_samples"` matches Plato's traditional +sample-weighted FedAvg behavior. FedAvg equivalence for outer SGD with learning +rate `1.0` is valid only when both runs use the same weighting rule. + +Unsupported modes must fail clearly. They must not silently fall back to an +approximate DiLoCo variant. Examples include trainer backends that cannot count +local optimizer steps exactly, execution paths that cannot preserve +client-local optimizer and scheduler state, or payload paths that would send +optimizer state to the server. + +## Implementation Sequence + +Dependency graph: + +```text +D1 +|-- D2 --> D3 +|-- D4 --> D5 +|-- D6 --> D7 +|-- D8 --> D9 +`-- D10 --> D11 + +D3, D5, D7, D9, D11 --> D12 --> D13 +``` + +Tasks: + +```yaml +- id: D1 + depends_on: [] + task: Document the exact DiLoCo contract and unsupported modes. + +- id: D2 + depends_on: [D1] + task: Add red tests for server-side outer gradient sign, weighting, and + FedAvg equivalence under matching weighting. + +- id: D3 + depends_on: [D2] + task: Implement DiLoCo server aggregation and outer optimizer state for SGD, + momentum SGD, and Nesterov. + +- id: D4 + depends_on: [D1] + task: Add red tests for exact local optimizer-step counting and `H` smaller + than one epoch. + +- id: D5 + depends_on: [D4] + task: Implement `trainer.local_steps_per_round` with mid-epoch termination + after exactly `H` optimizer steps. + +- id: D6 + depends_on: [D1] + task: Add red tests for per-client optimizer and scheduler state + persistence. + +- id: D7 + depends_on: [D6] + task: Persist client-local optimizer and scheduler state without sending it + to the server. + +- id: D8 + depends_on: [D1] + task: Add red tests for round-aware small-`H` sampling. + +- id: D9 + depends_on: [D8] + task: Implement round-aware resampling or an equivalent persistent sampling + stream for each logical client. + +- id: D10 + depends_on: [D1] + task: Add red tests for parameter and buffer eligibility. + +- id: D11 + depends_on: [D10] + task: Implement the default trainable-parameter-only outer optimizer policy + and conservative buffer synchronization. + +- id: D12 + depends_on: [D3, D5, D7, D9, D11] + task: Wire exact DiLoCo configuration, examples, and user-facing + documentation. + +- id: D13 + depends_on: [D12] + task: Add end-to-end faithful-mode validation coverage. +``` + +Every implementation task should use red/green test-driven development. Add +the failing tests that describe the contract first, then implement the smallest +runtime change that makes those tests pass. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index fa5bc9908..f4177c8b4 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -88,7 +88,9 @@ nav: - Servers: references/servers.md - Trainers: references/trainers.md - Evaluators: references/evaluators.md - - Developer's Guide: development.md + - Developer's Guide: + - Overview: development.md + - DiLoCo Design Contract: development/diloco.md - Deployment Guide: deployment.md - Digital Research Alliance of Canada: ccdb.md - Miscellaneous Notes: misc.md From 0378a578fe59f8ff3805b90abd69ad6631be2b30 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:46:34 -0400 Subject: [PATCH 02/21] Implemented DiLoCo outer aggregation. Adds the DiLoCo aggregation strategy with server-side SGD, momentum SGD, and Nesterov outer optimizer behavior over Plato-style client deltas. Covers uniform and sample-weighted aggregation, validates configuration values, and adds focused tests for sign handling, FedAvg equivalence under matching weighting, momentum state persistence, reset, and stale-key cleanup.\n\nValidation reported by worker:\n- uv run pytest tests/servers/test_diloco_strategy.py\n- uv run pytest tests/servers/test_fedavg_strategy.py\n- uv run ruff check . --select I\n\nCovers Linear issue DT-410. --- .../strategies/aggregation/__init__.py | 2 + .../servers/strategies/aggregation/diloco.py | 315 ++++++++++++++++ tests/servers/test_diloco_strategy.py | 336 ++++++++++++++++++ 3 files changed, 653 insertions(+) create mode 100644 plato/servers/strategies/aggregation/diloco.py create mode 100644 tests/servers/test_diloco_strategy.py diff --git a/plato/servers/strategies/aggregation/__init__.py b/plato/servers/strategies/aggregation/__init__.py index f44c420a4..fe4174402 100644 --- a/plato/servers/strategies/aggregation/__init__.py +++ b/plato/servers/strategies/aggregation/__init__.py @@ -4,6 +4,7 @@ Each strategy is defined in its own module for clarity. """ +from plato.servers.strategies.aggregation.diloco import DiLoCoAggregationStrategy from plato.servers.strategies.aggregation.fedasync import FedAsyncAggregationStrategy from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy from plato.servers.strategies.aggregation.fedbuff import FedBuffAggregationStrategy @@ -16,6 +17,7 @@ __all__ = [ "FedAvgAggregationStrategy", + "DiLoCoAggregationStrategy", "FedBuffAggregationStrategy", "FedNovaAggregationStrategy", "FedAsyncAggregationStrategy", diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py new file mode 100644 index 000000000..2b2df57e3 --- /dev/null +++ b/plato/servers/strategies/aggregation/diloco.py @@ -0,0 +1,315 @@ +""" +DiLoCo aggregation strategy. + +The strategy consumes Plato-style client deltas (`client_after - global_before`), +converts them to DiLoCo outer gradients, and returns Plato-compatible server +deltas for `algorithm.update_weights()` to add to the global model. +""" + +from __future__ import annotations + +import asyncio +import copy +import numbers +from collections.abc import Callable, Mapping +from types import SimpleNamespace +from typing import Any, cast + +import numpy as np + +from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy +from plato.servers.strategies.base import ServerContext + +try: # pragma: no cover - optional dependency + import torch +except ImportError: # pragma: no cover + torch = cast(Any, None) + + +class DiLoCoAggregationStrategy(FedAvgAggregationStrategy): + """Aggregate client deltas with a server-side DiLoCo outer optimizer.""" + + _SUPPORTED_OPTIMIZERS = {"sgd", "sgdm", "nesterov"} + _SUPPORTED_WEIGHTING_MODES = {"uniform", "num_samples"} + + def __init__( + self, + outer_optimizer: str = "nesterov", + outer_learning_rate: float = 0.7, + outer_momentum: float = 0.9, + aggregation_weighting: str = "uniform", + ): + super().__init__() + self.outer_optimizer = self._validate_outer_optimizer(outer_optimizer) + self.outer_learning_rate = self._validate_learning_rate( + outer_learning_rate + ) + self.outer_momentum = self._validate_momentum(outer_momentum) + self.aggregation_weighting = self._validate_weighting_mode( + aggregation_weighting + ) + self.momentum_state: dict[str, Any] = {} + + async def aggregate_deltas( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + context: ServerContext, + ) -> dict: + """Aggregate deltas and apply the configured DiLoCo outer optimizer.""" + eligible = self._eligible_updates(updates, deltas_received) + if not eligible: + self._remove_stale_momentum(set()) + return self._empty_delta(context, self._first_delta(deltas_received)) + + weights = self._aggregation_weights(eligible) + if not weights: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta: Any = None + for (_, delta, _), weight in zip(eligible, weights): + avg_delta = self._accumulate_weighted(avg_delta, delta, weight, context) + await asyncio.sleep(0) + + if avg_delta is None: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta = self._match_reference_structure(avg_delta, eligible[0][1]) + outer_gradient = self._scale_tree(avg_delta, -1.0) + server_delta, active_paths = self._apply_outer_optimizer(outer_gradient) + self._remove_stale_momentum(active_paths) + + return self._match_reference_structure(server_delta, eligible[0][1]) + + @classmethod + def _validate_outer_optimizer(cls, value: str) -> str: + optimizer = str(value).lower() + if optimizer not in cls._SUPPORTED_OPTIMIZERS: + supported = ", ".join(sorted(cls._SUPPORTED_OPTIMIZERS)) + raise ValueError( + f"Invalid outer_optimizer '{value}'. Supported values: {supported}." + ) + return optimizer + + @staticmethod + def _validate_learning_rate(value: float) -> float: + learning_rate = float(value) + if learning_rate < 0: + raise ValueError("outer_learning_rate must be nonnegative.") + return learning_rate + + @staticmethod + def _validate_momentum(value: float) -> float: + momentum = float(value) + if not 0 <= momentum < 1: + raise ValueError("outer_momentum must be in the range [0, 1).") + return momentum + + @classmethod + def _validate_weighting_mode(cls, value: str) -> str: + weighting = str(value).lower() + if weighting not in cls._SUPPORTED_WEIGHTING_MODES: + supported = ", ".join(sorted(cls._SUPPORTED_WEIGHTING_MODES)) + raise ValueError( + "Invalid aggregation_weighting " + f"'{value}'. Supported values: {supported}." + ) + return weighting + + def _eligible_updates( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + ) -> list[tuple[SimpleNamespace, dict, float]]: + eligible: list[tuple[SimpleNamespace, dict, float]] = [] + for update, delta in zip(updates, deltas_received): + if getattr(update.report, "type", "weights") == "features": + continue + + num_samples = self._num_samples(update) + if num_samples <= 0: + continue + + eligible.append((update, delta, num_samples)) + + return eligible + + @staticmethod + def _num_samples(update: SimpleNamespace) -> float: + try: + return float(update.report.num_samples) + except (AttributeError, TypeError, ValueError): + return 0.0 + + def _aggregation_weights( + self, eligible: list[tuple[SimpleNamespace, dict, float]] + ) -> list[float]: + if not eligible: + return [] + + if self.aggregation_weighting == "uniform": + return [1.0 / len(eligible)] * len(eligible) + + total_samples = sum(num_samples for _, _, num_samples in eligible) + if total_samples <= 0: + return [] + + return [num_samples / total_samples for _, _, num_samples in eligible] + + def _apply_outer_optimizer(self, outer_gradient: Any) -> tuple[Any, set[str]]: + active_paths: set[str] = set() + + if self.outer_optimizer == "sgd": + return self._scale_tree(outer_gradient, -self.outer_learning_rate), set() + + server_delta = self._map_tree( + outer_gradient, + lambda value, path: self._apply_momentum_leaf( + value, path, active_paths + ), + ) + return server_delta, active_paths + + def _apply_momentum_leaf( + self, outer_gradient: Any, path: str, active_paths: set[str] + ) -> Any: + active_paths.add(path) + previous = self.momentum_state.get(path) + if previous is not None and not self._is_compatible(previous, outer_gradient): + previous = None + + if previous is None: + momentum = self._clone_tree(outer_gradient) + else: + momentum = self._add_values( + self._scale_tree(previous, self.outer_momentum), + outer_gradient, + ) + + self.momentum_state[path] = self._clone_tree(momentum) + + if self.outer_optimizer == "nesterov": + direction = self._add_values( + outer_gradient, + self._scale_tree(momentum, self.outer_momentum), + ) + else: + direction = momentum + + return self._scale_tree(direction, -self.outer_learning_rate) + + def _remove_stale_momentum(self, active_paths: set[str]) -> None: + if self.outer_optimizer == "sgd": + self.momentum_state.clear() + return + + for path in list(self.momentum_state): + if path not in active_paths: + del self.momentum_state[path] + + def _empty_delta(self, context: ServerContext, reference_delta: Any | None) -> dict: + zero_delta = self._zero_delta(context, reference_delta) + if zero_delta is not None: + return zero_delta + + if reference_delta is None: + return {} + + return self._scale_tree(reference_delta, 0.0) + + @staticmethod + def _first_delta(deltas_received: list[dict]) -> dict | None: + return deltas_received[0] if deltas_received else None + + def _map_tree(self, value: Any, leaf_fn: Callable[[Any, str], Any], path="") -> Any: + if isinstance(value, Mapping): + return { + key: self._map_tree(item, leaf_fn, self._join_path(path, key)) + for key, item in value.items() + } + + if isinstance(value, list): + return [ + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ] + + if isinstance(value, tuple): + return tuple( + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ) + + return leaf_fn(value, path) + + def _scale_tree(self, value: Any, scalar: float) -> Any: + if isinstance(value, Mapping): + return { + key: self._scale_tree(item, scalar) for key, item in value.items() + } + + if isinstance(value, list): + return [self._scale_tree(item, scalar) for item in value] + + if isinstance(value, tuple): + return tuple(self._scale_tree(item, scalar) for item in value) + + return value * scalar + + @staticmethod + def _add_values(left: Any, right: Any) -> Any: + return left + right + + def _clone_tree(self, value: Any) -> Any: + if isinstance(value, Mapping): + return {key: self._clone_tree(item) for key, item in value.items()} + + if isinstance(value, list): + return [self._clone_tree(item) for item in value] + + if isinstance(value, tuple): + return tuple(self._clone_tree(item) for item in value) + + if torch is not None and isinstance(value, torch.Tensor): + return value.detach().clone() + + if isinstance(value, np.ndarray): + return value.copy() + + try: + return copy.deepcopy(value) + except TypeError: + return value + + @staticmethod + def _is_compatible(left: Any, right: Any) -> bool: + if torch is not None and isinstance(left, torch.Tensor): + return ( + isinstance(right, torch.Tensor) + and left.shape == right.shape + and left.dtype == right.dtype + ) + + if isinstance(left, np.ndarray): + return ( + isinstance(right, np.ndarray) + and left.shape == right.shape + and left.dtype == right.dtype + ) + + left_shape = getattr(left, "shape", None) + right_shape = getattr(right, "shape", None) + if left_shape is not None or right_shape is not None: + return ( + left_shape == right_shape + and getattr(left, "dtype", None) == getattr(right, "dtype", None) + ) + + return isinstance(left, numbers.Number) and isinstance(right, numbers.Number) + + @staticmethod + def _join_path(prefix: str, key: Any) -> str: + key_text = str(key) + return key_text if not prefix else f"{prefix}.{key_text}" diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py new file mode 100644 index 000000000..336e5e1aa --- /dev/null +++ b/tests/servers/test_diloco_strategy.py @@ -0,0 +1,336 @@ +"""Tests for DiLoCo server-side outer aggregation.""" + +import asyncio +from types import SimpleNamespace + +import pytest +import torch + +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy +from plato.servers.strategies.base import ServerContext + + +class DummyAlgorithm: + """Minimal algorithm stub for zero-delta construction.""" + + def __init__(self, baseline): + self.baseline = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.baseline.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + return [ + { + name: weights[name] - baseline_weights[name] + for name in baseline_weights.keys() + } + for weights in weights_list + ] + + +def _context(baseline=None): + context = ServerContext() + if baseline is not None: + context.algorithm = DummyAlgorithm(baseline) + return context + + +def _update(num_samples, report_type="weights"): + return SimpleNamespace( + report=SimpleNamespace(num_samples=num_samples, type=report_type) + ) + + +def _aggregate(strategy, updates, deltas, baseline=None): + return asyncio.run( + strategy.aggregate_deltas(updates, deltas, _context(baseline)) + ) + + +def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): + """Outer SGD with lr=1 should match uniform averaging under uniform mode.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(99)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([15.0])) + + +def test_sgd_lr_one_num_samples_matches_weighted_fedavg(temp_config): + """Outer SGD with lr=1 should match sample-weighted FedAvg.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="num_samples", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(3)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([6.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([16.5])) + + +def test_sgd_lr_half_moves_halfway_to_averaged_model(temp_config): + """A lower outer SGD lr should partially move toward the averaged model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.5, + aggregation_weighting="uniform", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(5), _update(5)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([2.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([12.5])) + + +def test_sgd_uses_diloco_outer_gradient_sign(temp_config): + """The strategy should negate Plato deltas before applying outer SGD.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.25, + aggregation_weighting="uniform", + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([1.0])) + + +def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): + """Uniform mode should weight eligible clients equally.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + server_delta = _aggregate( + strategy, + [_update(1), _update(1000)], + [{"w": torch.tensor([0.0])}, {"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([5.0])) + + +def test_nonpositive_sample_reports_are_ineligible(temp_config): + """Reports with zero or negative sample counts should not affect averages.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="num_samples", + ) + + server_delta = _aggregate( + strategy, + [_update(0), _update(-5), _update(10)], + [ + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([4.0])}, + ], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0])) + + +def test_empty_eligible_updates_return_zero_delta(temp_config): + """An empty eligible set should produce a zero delta matching the baseline.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + baseline = {"w": torch.tensor([3.0, 4.0])} + server_delta = _aggregate( + strategy, + [_update(0), _update(5, report_type="features")], + [{"w": torch.tensor([10.0, 10.0])}, {"w": torch.tensor([10.0, 10.0])}], + baseline, + ) + + assert torch.allclose(server_delta["w"], torch.zeros_like(baseline["w"])) + + +def test_empty_eligible_updates_remove_stale_momentum(temp_config): + """A round with no eligible keys should clear stale momentum buffers.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + server_delta = _aggregate( + strategy, + [_update(0)], + [{"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([0.0])) + assert strategy.momentum_state == {} + + +def test_sgdm_persists_momentum_across_rounds(temp_config): + """Momentum SGD should reuse server-side outer momentum across rounds.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([2.0])) + assert torch.allclose(second_delta["w"], torch.tensor([5.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_nesterov_uses_pytorch_style_two_round_recurrence(temp_config): + """Nesterov should use g + beta * m after updating the momentum buffer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([3.0])) + assert torch.allclose(second_delta["w"], torch.tensor([6.5])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_momentum_state_resets_on_shape_mismatch_and_removes_stale_keys( + temp_config, +): + """Momentum state should reset incompatible keys and prune missing keys.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0]), "b": torch.tensor([1.0])}], + {"w": torch.tensor([0.0]), "b": torch.tensor([0.0])}, + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0, 6.0])}], + {"w": torch.tensor([0.0, 0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0, 6.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-4.0, -6.0])) + assert "b" not in strategy.momentum_state + + +def test_momentum_state_resets_on_dtype_mismatch(temp_config): + """Momentum state should reset when the tensor dtype changes.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0], dtype=torch.float32)}], + {"w": torch.tensor([0.0], dtype=torch.float32)}, + ) + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0], dtype=torch.float64)}], + {"w": torch.tensor([0.0], dtype=torch.float64)}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0], dtype=torch.float64)) + assert strategy.momentum_state["w"].dtype == torch.float64 + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"outer_optimizer": "adam"}, "outer_optimizer"), + ({"aggregation_weighting": "weighted"}, "aggregation_weighting"), + ({"outer_learning_rate": -0.1}, "outer_learning_rate"), + ({"outer_momentum": -0.1}, "outer_momentum"), + ({"outer_momentum": 1.0}, "outer_momentum"), + ], +) +def test_invalid_config_values_fail_clearly(temp_config, kwargs, match): + """Invalid DiLoCo aggregation configuration should raise clear errors.""" + with pytest.raises(ValueError, match=match): + DiLoCoAggregationStrategy(**kwargs) From 6287b0f94287dd48c18aec525af9aea0dfec3c4c Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:47:22 -0400 Subject: [PATCH 03/21] Added exact local step limits for trainers. Adds trainer.local_steps_per_round support to ComposableTrainer so local work can stop after an exact number of completed optimizer steps, including mid-epoch DiLoCo-style runs. The trainer counts optimizer steps rather than raw batches, avoids finalization after the limit is reached, preserves existing epoch behavior when unset, and adds focused tests for delayed optimizer stepping, cleanup, and invalid values.\n\nValidation reported by worker:\n- uv run pytest tests/trainers/test_composable_trainer.py -k local_steps\n- uv run pytest tests/trainers/test_composable_trainer.py\n- uv run ruff check . --select I\n\nThe broader ============================= test session starts ============================== platform darwin -- Python 3.13.12, pytest-8.4.2, pluggy-1.6.0 rootdir: /Users/bli/Playground/plato configfile: pyproject.toml plugins: anyio-4.13.0 collected 106 items / 1 error / 75 deselected / 31 selected ==================================== ERRORS ==================================== _______ ERROR collecting tests/trainers/test_dp_data_loader_strategy.py ________ ImportError while importing test module '/Users/bli/Playground/plato/tests/trainers/test_dp_data_loader_strategy.py'. Hint: make sure your test modules/packages have valid Python names. Traceback: ../../.local/share/uv/python/cpython-3.13.12-macos-aarch64-none/lib/python3.13/importlib/__init__.py:88: in import_module return _bootstrap._gcd_import(name[level:], package, level) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ tests/trainers/test_dp_data_loader_strategy.py:4: in from plato.trainers.diff_privacy import DPDataLoaderStrategy plato/trainers/diff_privacy.py:15: in from opacus import GradSampleModule E ModuleNotFoundError: No module named 'opacus' =========================== short test summary info ============================ ERROR tests/trainers/test_dp_data_loader_strategy.py !!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!! ======================= 75 deselected, 1 error in 0.11s ======================== collection is blocked by missing optional dependency opacus in unrelated DP tests.\n\nCovers Linear issue DT-416. --- plato/trainers/composable.py | 47 ++++- tests/trainers/test_composable_trainer.py | 206 ++++++++++++++++++++++ 2 files changed, 251 insertions(+), 2 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 1e98d1128..f2bc2f0ca 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -177,6 +177,29 @@ def _require_model(self) -> nn.Module: ) return cast(nn.Module, self.model) + @staticmethod + def _local_steps_per_round(config: dict[str, Any]) -> int | None: + """Return the optional local optimizer-step limit for one train run.""" + value = config.get("local_steps_per_round") + if value is None: + return None + + if isinstance(value, bool) or not isinstance(value, int) or value <= 0: + raise ValueError( + "trainer.local_steps_per_round must be a positive integer." + ) + + return value + + def _record_local_optimizer_step(self, local_steps_per_round: int | None) -> bool: + """Record one completed optimizer step and report whether H was reached.""" + if local_steps_per_round is None: + return False + + completed_steps = int(self.context.state.get("local_optimizer_steps", 0)) + 1 + self.context.state["local_optimizer_steps"] = completed_steps + return completed_steps >= local_steps_per_round + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -397,6 +420,12 @@ def train_model(self, config, trainset, sampler, **kwargs): self.sampler = sampler self.context.config = config self.context.current_round = self.current_round + local_steps_per_round = self._local_steps_per_round(config) + self.context.state["local_optimizer_steps"] = 0 + if local_steps_per_round is None: + self.context.state.pop("local_steps_per_round", None) + else: + self.context.state["local_steps_per_round"] = local_steps_per_round # Ensure training step strategy respects higher-order gradient settings if self.training_step_strategy is not None: @@ -494,6 +523,7 @@ def train_model(self, config, trainset, sampler, **kwargs): total_epochs = config["epochs"] tic = time.perf_counter() training_stop_requested = False + local_step_limit_reached = False try: total_batches = len(self.train_loader) except (TypeError, AttributeError): @@ -564,6 +594,9 @@ def compute_loss(outputs, labels_inner): self.optimizer_strategy.on_optimizer_step( self.optimizer, self.context ) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) # Strategy hook: after_step self.model_update_strategy.after_step(self.context) @@ -591,7 +624,7 @@ def compute_loss(outputs, labels_inner): ): self._handle_control_log() - if control_actions.get("stop_training"): + if control_actions.get("stop_training") or local_step_limit_reached: training_stop_requested = True break @@ -601,7 +634,11 @@ def compute_loss(outputs, labels_inner): finalize_loss = None finalize_step_done = False finalize_callable = getattr(self.training_step_strategy, "finalize", None) - if batches_seen and callable(finalize_callable): + if ( + batches_seen + and callable(finalize_callable) + and not local_step_limit_reached + ): finalize_loss = finalize_callable( model=model, optimizer=self.optimizer, @@ -613,6 +650,9 @@ def compute_loss(outputs, labels_inner): ) if finalize_step_done: self.optimizer_strategy.on_optimizer_step(self.optimizer, self.context) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) self.model_update_strategy.after_step(self.context) self.callback_handler.call_event( "on_train_step_end", @@ -652,6 +692,9 @@ def compute_loss(outputs, labels_inner): # No batches remain, but respect control flag. pass + if local_step_limit_reached: + training_stop_requested = True + self.context.state.pop("is_last_batch", None) self.context.state.pop("hf_optimizer_step_index", None) diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index ff35f76ec..ec83e20ea 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.utils.data import TensorDataset +from plato.callbacks.trainer import TrainerCallback from plato.config import Config from plato.evaluators.runner import EVALUATION_PRIMARY_KEY, EVALUATION_RESULTS_KEY from plato.trainers.composable import ComposableTrainer @@ -25,6 +26,7 @@ LossCriterionStrategy, ModelUpdateStrategy, TrainingContext, + TrainingStepStrategy, ) @@ -178,6 +180,210 @@ def test_multiple_epochs(self, simple_model, simple_dataset): assert len(trainer.run_history.get_metric_values("train_loss")) == 5 +class TestComposableTrainerLocalSteps: + """Test local optimizer-step limits for DiLoCo-style local work.""" + + class CountingCallback(TrainerCallback): + def __init__(self): + self.train_run_end_called = False + self.train_step_end_count = 0 + + def on_train_step_end(self, trainer, config, batch, loss, **kwargs): + self.train_step_end_count += 1 + + def on_train_run_end(self, trainer, config, **kwargs): + self.train_run_end_called = True + + class CountingUpdateStrategy(ModelUpdateStrategy): + def __init__(self): + self.after_step_count = 0 + self.on_train_end_called = False + + def after_step(self, context): + self.after_step_count += 1 + + def on_train_end(self, context): + self.on_train_end_called = True + + class CountingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.batch_count += 1 + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + class DelayedOptimizerStepStrategy(TrainingStepStrategy): + def __init__(self, accumulation_steps=2, finalize_steps=True): + self.accumulation_steps = accumulation_steps + self.finalize_steps = finalize_steps + self.raw_batch_count = 0 + self.optimizer_step_count = 0 + self.finalize_calls = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + outputs = model(examples) + loss = loss_criterion(outputs, labels) + (loss / self.accumulation_steps).backward() + + self.raw_batch_count += 1 + if self.raw_batch_count % self.accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False + + return loss + + def finalize(self, model, optimizer, context): + self.finalize_calls += 1 + if not self.finalize_steps: + return None + + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + return torch.tensor(0.0) + + def test_local_steps_stop_mid_epoch_and_run_cleanup( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + callback = self.CountingCallback() + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + callbacks=[callback], + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 3 + assert update_strategy.after_step_count == 3 + assert callback.train_step_end_count == 3 + assert trainer.current_epoch == 1 + assert trainer.context.state["local_optimizer_steps"] == 3 + assert update_strategy.on_train_end_called + assert callback.train_run_end_called + + def test_local_steps_count_optimizer_steps_not_raw_batches( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.DelayedOptimizerStepStrategy(accumulation_steps=3) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert step_strategy.optimizer_step_count == 2 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + + def test_local_steps_skip_finalize_after_limit_is_reached( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 2, + "local_steps_per_round": 1, + } + step_strategy = self.DelayedOptimizerStepStrategy(accumulation_steps=2) + update_strategy = self.CountingUpdateStrategy() + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 2 + assert step_strategy.optimizer_step_count == 1 + assert step_strategy.finalize_calls == 0 + assert update_strategy.after_step_count == 1 + assert trainer.context.state["local_optimizer_steps"] == 1 + + def test_epoch_behavior_is_unchanged_when_local_steps_unset( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 10, + "epochs": 2, + } + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 20 + assert trainer.current_epoch == 2 + assert len(trainer.run_history.get_metric_values("train_loss")) == 2 + + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) + def test_invalid_local_steps_fail_clearly( + self, simple_model, simple_dataset, simple_config, local_steps_per_round + ): + config = { + **simple_config, + "local_steps_per_round": local_steps_per_round, + } + trainer = ComposableTrainer(model=simple_model) + + with pytest.raises(ValueError, match="local_steps_per_round"): + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + class TestComposableTrainerStrategies: """Test strategy integration.""" From b91d97d67cf4958de9a72e02b45f7f351ecd92cf Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:52:00 -0400 Subject: [PATCH 04/21] Fixed local step counting for accumulation. DT-417 review found that the built-in GradientAccumulationStepStrategy did not publish optimizer_step_completed, so local_steps_per_round counted raw batches instead of optimizer steps when accumulation_steps > 1. Set optimizer_step_completed only when the accumulation strategy actually performs optimizer.step(), and add a regression test that uses the real built-in accumulation strategy with H=2 and accumulation_steps=3. Validation: uv run pytest tests/trainers/test_composable_trainer.py -k local_steps; uv run pytest tests/trainers/test_composable_trainer.py; uv run ruff check . --select I. --- plato/trainers/strategies/training_step.py | 3 ++ tests/trainers/test_composable_trainer.py | 50 ++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/plato/trainers/strategies/training_step.py b/plato/trainers/strategies/training_step.py index b4aba6d9d..5afa9702c 100644 --- a/plato/trainers/strategies/training_step.py +++ b/plato/trainers/strategies/training_step.py @@ -128,6 +128,9 @@ def training_step( if self.current_step % self.accumulation_steps == 0: optimizer.step() optimizer.zero_grad() + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False # Return unscaled loss for logging return loss diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index ec83e20ea..d5daeb4b4 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -19,6 +19,7 @@ CrossEntropyLossStrategy, DefaultDataLoaderStrategy, DefaultTrainingStepStrategy, + GradientAccumulationStepStrategy, NoOpUpdateStrategy, NoSchedulerStrategy, ) @@ -272,6 +273,30 @@ def finalize(self, model, optimizer, context): context.state["optimizer_step_completed"] = True return torch.tensor(0.0) + class CountingGradientAccumulationStepStrategy(GradientAccumulationStepStrategy): + def __init__(self, accumulation_steps): + super().__init__(accumulation_steps=accumulation_steps) + self.raw_batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.raw_batch_count += 1 + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + def test_local_steps_stop_mid_epoch_and_run_cleanup( self, simple_model, simple_dataset, simple_config ): @@ -325,6 +350,31 @@ def test_local_steps_count_optimizer_steps_not_raw_batches( assert update_strategy.after_step_count == 2 assert trainer.context.state["local_optimizer_steps"] == 2 + def test_local_steps_respect_builtin_gradient_accumulation( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingGradientAccumulationStepStrategy( + accumulation_steps=3 + ) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + def test_local_steps_skip_finalize_after_limit_is_reached( self, simple_model, simple_dataset, simple_config ): From 6846f5097dffa8c01825189fee51163dc7036017 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:56:05 -0400 Subject: [PATCH 05/21] Added DiLoCo parameter eligibility policy. DT-412 makes the DiLoCo outer optimizer apply to trainable floating parameters by default while preserving full state_dict safety for frozen parameters and buffers. The new apply_outer_optimizer_to option supports parameters and all_floating modes, validates unsupported values clearly, resolves trainable parameter names from context.trainer.model for the default mode, and keeps momentum state only for tensors that receive outer optimization. Tests cover trainable parameters, frozen parameters, floating buffers, integer and boolean buffers, all_floating behavior, missing model context, invalid config values, and the existing DiLoCo aggregation math. Validation: uv run pytest tests/servers/test_diloco_strategy.py; uv run pytest tests/servers/test_fedavg_strategy.py; uv run ruff check . --select I. --- .../servers/strategies/aggregation/diloco.py | 142 ++++++++++++- tests/servers/test_diloco_strategy.py | 191 +++++++++++++++++- 2 files changed, 321 insertions(+), 12 deletions(-) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index 2b2df57e3..a00316a9a 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -31,6 +31,7 @@ class DiLoCoAggregationStrategy(FedAvgAggregationStrategy): _SUPPORTED_OPTIMIZERS = {"sgd", "sgdm", "nesterov"} _SUPPORTED_WEIGHTING_MODES = {"uniform", "num_samples"} + _SUPPORTED_APPLY_POLICIES = {"parameters", "all_floating"} def __init__( self, @@ -38,6 +39,7 @@ def __init__( outer_learning_rate: float = 0.7, outer_momentum: float = 0.9, aggregation_weighting: str = "uniform", + apply_outer_optimizer_to: str = "parameters", ): super().__init__() self.outer_optimizer = self._validate_outer_optimizer(outer_optimizer) @@ -48,6 +50,9 @@ def __init__( self.aggregation_weighting = self._validate_weighting_mode( aggregation_weighting ) + self.apply_outer_optimizer_to = self._validate_apply_policy( + apply_outer_optimizer_to + ) self.momentum_state: dict[str, Any] = {} async def aggregate_deltas( @@ -77,8 +82,10 @@ async def aggregate_deltas( return self._empty_delta(context, eligible[0][1]) avg_delta = self._match_reference_structure(avg_delta, eligible[0][1]) - outer_gradient = self._scale_tree(avg_delta, -1.0) - server_delta, active_paths = self._apply_outer_optimizer(outer_gradient) + optimizer_paths = self._outer_optimizer_paths(avg_delta, context) + server_delta, active_paths = self._apply_outer_optimizer( + avg_delta, optimizer_paths + ) self._remove_stale_momentum(active_paths) return self._match_reference_structure(server_delta, eligible[0][1]) @@ -118,6 +125,17 @@ def _validate_weighting_mode(cls, value: str) -> str: ) return weighting + @classmethod + def _validate_apply_policy(cls, value: str) -> str: + policy = str(value).lower() + if policy not in cls._SUPPORTED_APPLY_POLICIES: + supported = ", ".join(sorted(cls._SUPPORTED_APPLY_POLICIES)) + raise ValueError( + "Invalid apply_outer_optimizer_to " + f"'{value}'. Supported values: {supported}." + ) + return policy + def _eligible_updates( self, updates: list[SimpleNamespace], @@ -158,20 +176,48 @@ def _aggregation_weights( return [num_samples / total_samples for _, _, num_samples in eligible] - def _apply_outer_optimizer(self, outer_gradient: Any) -> tuple[Any, set[str]]: - active_paths: set[str] = set() + def _outer_optimizer_paths( + self, avg_delta: Any, context: ServerContext + ) -> set[str]: + if self.apply_outer_optimizer_to == "all_floating": + return self._floating_leaf_paths(avg_delta) + + trainable_parameter_names = self._trainable_parameter_names(context) + return self._collect_leaf_paths( + avg_delta, + lambda value, path: path in trainable_parameter_names + and self._is_floating_value(value), + ) - if self.outer_optimizer == "sgd": - return self._scale_tree(outer_gradient, -self.outer_learning_rate), set() + def _apply_outer_optimizer( + self, avg_delta: Any, optimizer_paths: set[str] + ) -> tuple[Any, set[str]]: + active_paths: set[str] = set() server_delta = self._map_tree( - outer_gradient, - lambda value, path: self._apply_momentum_leaf( - value, path, active_paths + avg_delta, + lambda value, path: self._apply_outer_optimizer_leaf( + value, path, optimizer_paths, active_paths ), ) return server_delta, active_paths + def _apply_outer_optimizer_leaf( + self, + avg_delta: Any, + path: str, + optimizer_paths: set[str], + active_paths: set[str], + ) -> Any: + if path not in optimizer_paths: + return avg_delta + + outer_gradient = self._scale_tree(avg_delta, -1.0) + if self.outer_optimizer == "sgd": + return self._scale_tree(outer_gradient, -self.outer_learning_rate) + + return self._apply_momentum_leaf(outer_gradient, path, active_paths) + def _apply_momentum_leaf( self, outer_gradient: Any, path: str, active_paths: set[str] ) -> Any: @@ -209,6 +255,84 @@ def _remove_stale_momentum(self, active_paths: set[str]) -> None: if path not in active_paths: del self.momentum_state[path] + def _trainable_parameter_names(self, context: ServerContext) -> set[str]: + model = self._model_from_context(context) + trainable_names: set[str] = set() + + for name, parameter in model.named_parameters(): + if getattr(parameter, "requires_grad", False) and self._is_floating_value( + parameter + ): + trainable_names.add(name) + + return trainable_names + + @staticmethod + def _model_from_context(context: ServerContext) -> Any: + trainer = getattr(context, "trainer", None) + model = getattr(trainer, "model", None) if trainer is not None else None + if model is None or not hasattr(model, "named_parameters"): + raise AttributeError( + "DiLoCo apply_outer_optimizer_to='parameters' requires " + "context.trainer.model with named_parameters()." + ) + return model + + def _floating_leaf_paths(self, value: Any) -> set[str]: + return self._collect_leaf_paths( + value, lambda leaf, _: self._is_floating_value(leaf) + ) + + def _collect_leaf_paths( + self, + value: Any, + predicate: Callable[[Any, str], bool], + path: str = "", + ) -> set[str]: + if isinstance(value, Mapping): + paths: set[str] = set() + for key, item in value.items(): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, key) + ) + ) + return paths + + if isinstance(value, list): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + if isinstance(value, tuple): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + return {path} if predicate(value, path) else set() + + @staticmethod + def _is_floating_value(value: Any) -> bool: + if torch is not None and isinstance(value, torch.Tensor): + return torch.is_floating_point(value) + + if isinstance(value, np.ndarray): + return np.issubdtype(value.dtype, np.floating) + + return isinstance(value, numbers.Real) and not isinstance( + value, (numbers.Integral, bool) + ) + def _empty_delta(self, context: ServerContext, reference_delta: Any | None) -> dict: zero_delta = self._zero_delta(context, reference_delta) if zero_delta is not None: diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index 336e5e1aa..a7b60b78d 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -35,10 +35,24 @@ def compute_weight_deltas(self, baseline_weights, weights_list): ] -def _context(baseline=None): +class MixedStateModel(torch.nn.Module): + """Model exposing trainable, frozen, floating-buffer, and integer state.""" + + def __init__(self): + super().__init__() + self.trainable = torch.nn.Parameter(torch.tensor([1.0])) + self.frozen = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=False) + self.register_buffer("floating_buffer", torch.tensor([1.0])) + self.register_buffer("integer_buffer", torch.tensor([1], dtype=torch.int64)) + self.register_buffer("bool_buffer", torch.tensor([True], dtype=torch.bool)) + + +def _context(baseline=None, model=None): context = ServerContext() if baseline is not None: context.algorithm = DummyAlgorithm(baseline) + if model is not None: + context.trainer = SimpleNamespace(model=model) return context @@ -48,9 +62,9 @@ def _update(num_samples, report_type="weights"): ) -def _aggregate(strategy, updates, deltas, baseline=None): +def _aggregate(strategy, updates, deltas, baseline=None, model=None): return asyncio.run( - strategy.aggregate_deltas(updates, deltas, _context(baseline)) + strategy.aggregate_deltas(updates, deltas, _context(baseline, model)) ) @@ -60,6 +74,7 @@ def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([10.0])} @@ -77,6 +92,7 @@ def test_sgd_lr_one_num_samples_matches_weighted_fedavg(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([10.0])} @@ -95,6 +111,7 @@ def test_sgd_lr_half_moves_halfway_to_averaged_model(temp_config): outer_optimizer="sgd", outer_learning_rate=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([10.0])} @@ -113,6 +130,7 @@ def test_sgd_uses_diloco_outer_gradient_sign(temp_config): outer_optimizer="sgd", outer_learning_rate=0.25, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) server_delta = _aggregate( @@ -131,6 +149,7 @@ def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) server_delta = _aggregate( @@ -149,6 +168,7 @@ def test_nonpositive_sample_reports_are_ineligible(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", ) server_delta = _aggregate( @@ -171,6 +191,7 @@ def test_empty_eligible_updates_return_zero_delta(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([3.0, 4.0])} @@ -191,6 +212,7 @@ def test_empty_eligible_updates_remove_stale_momentum(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) _aggregate( @@ -217,6 +239,7 @@ def test_sgdm_persists_momentum_across_rounds(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) first_delta = _aggregate( @@ -244,6 +267,7 @@ def test_nesterov_uses_pytorch_style_two_round_recurrence(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) first_delta = _aggregate( @@ -273,6 +297,7 @@ def test_momentum_state_resets_on_shape_mismatch_and_removes_stale_keys( outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) _aggregate( @@ -301,6 +326,7 @@ def test_momentum_state_resets_on_dtype_mismatch(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) _aggregate( @@ -320,11 +346,170 @@ def test_momentum_state_resets_on_dtype_mismatch(temp_config): assert strategy.momentum_state["w"].dtype == torch.float64 +def test_parameters_policy_optimizes_only_trainable_floating_parameters( + temp_config, +): + """Default policy should leave frozen parameters and buffers on FedAvg deltas.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + first_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + model, + ) + + assert torch.allclose(first_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(first_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(first_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(first_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(first_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-4.0])) + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + model, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([6.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([6.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) + assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-8.0])) + + +def test_all_floating_policy_optimizes_every_floating_state_tensor(temp_config): + """All-floating mode should not require model context for eligibility.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + server_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + ) + + assert torch.allclose(server_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(server_delta["frozen"], torch.tensor([2.0])) + assert torch.allclose(server_delta["floating_buffer"], torch.tensor([2.0])) + assert torch.equal(server_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(server_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + +def test_parameters_policy_requires_trainer_model_context(temp_config): + """Default parameter eligibility should fail clearly without a model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + with pytest.raises(AttributeError, match="context.trainer.model"): + _aggregate( + strategy, + [_update(1)], + [{"trainable": torch.tensor([2.0])}], + {"trainable": torch.tensor([0.0])}, + ) + + @pytest.mark.parametrize( ("kwargs", "match"), [ ({"outer_optimizer": "adam"}, "outer_optimizer"), ({"aggregation_weighting": "weighted"}, "aggregation_weighting"), + ({"apply_outer_optimizer_to": "buffers"}, "apply_outer_optimizer_to"), ({"outer_learning_rate": -0.1}, "outer_learning_rate"), ({"outer_momentum": -0.1}, "outer_momentum"), ({"outer_momentum": 1.0}, "outer_momentum"), From d122831ecb0471b6acbfddfb5975f70309a6586d Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:02:37 -0400 Subject: [PATCH 06/21] Handled adapter payload names in DiLoCo eligibility. DT-413 review found that default parameter eligibility missed trainable adapter parameters when adapter payload keys omit PEFT adapter-name segments, such as lora_A.weight versus lora_A.default.weight. Resolve trainable payload aliases from model adapter metadata and intersect them with the actual floating payload leaves, so exact state_dict keys still work while PEFT-style adapter payloads receive outer optimization and momentum. Added a PEFT-like regression test that fails without the alias mapping and verifies the payload key receives SGDM scaling and momentum state. Validation: uv run pytest tests/servers/test_diloco_strategy.py -k peft_adapter -q; uv run pytest tests/servers/test_diloco_strategy.py; uv run pytest tests/servers/test_fedavg_strategy.py; uv run ruff check . --select I. --- .../servers/strategies/aggregation/diloco.py | 60 ++++++++++++++++--- tests/servers/test_diloco_strategy.py | 41 +++++++++++++ 2 files changed, 94 insertions(+), 7 deletions(-) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index a00316a9a..18162fd14 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -182,12 +182,11 @@ def _outer_optimizer_paths( if self.apply_outer_optimizer_to == "all_floating": return self._floating_leaf_paths(avg_delta) - trainable_parameter_names = self._trainable_parameter_names(context) - return self._collect_leaf_paths( - avg_delta, - lambda value, path: path in trainable_parameter_names - and self._is_floating_value(value), + floating_paths = self._floating_leaf_paths(avg_delta) + trainable_parameter_names = self._trainable_parameter_names( + context, floating_paths ) + return floating_paths.intersection(trainable_parameter_names) def _apply_outer_optimizer( self, avg_delta: Any, optimizer_paths: set[str] @@ -255,18 +254,65 @@ def _remove_stale_momentum(self, active_paths: set[str]) -> None: if path not in active_paths: del self.momentum_state[path] - def _trainable_parameter_names(self, context: ServerContext) -> set[str]: + def _trainable_parameter_names( + self, context: ServerContext, payload_paths: set[str] | None = None + ) -> set[str]: model = self._model_from_context(context) + adapter_names = self._adapter_names(model) trainable_names: set[str] = set() for name, parameter in model.named_parameters(): if getattr(parameter, "requires_grad", False) and self._is_floating_value( parameter ): - trainable_names.add(name) + trainable_names.update( + self._payload_name_candidates(name, adapter_names, payload_paths) + ) return trainable_names + @staticmethod + def _adapter_names(model: Any) -> set[str]: + adapter_names = {"default"} + + peft_config = getattr(model, "peft_config", None) + if isinstance(peft_config, Mapping): + adapter_names.update(str(name) for name in peft_config) + + active_adapter = getattr(model, "active_adapter", None) + if isinstance(active_adapter, str): + adapter_names.add(active_adapter) + + active_adapters = getattr(model, "active_adapters", None) + if callable(active_adapters): + try: + adapter_names.update(str(name) for name in active_adapters()) + except TypeError: + pass + elif isinstance(active_adapters, (list, tuple, set)): + adapter_names.update(str(name) for name in active_adapters) + + return adapter_names + + @classmethod + def _payload_name_candidates( + cls, + parameter_name: str, + adapter_names: set[str], + payload_paths: set[str] | None, + ) -> set[str]: + candidates = {parameter_name} + parts = parameter_name.split(".") + for index, part in enumerate(parts): + if part not in adapter_names: + continue + + candidate = ".".join(parts[:index] + parts[index + 1 :]) + if payload_paths is None or candidate in payload_paths: + candidates.add(candidate) + + return candidates + @staticmethod def _model_from_context(context: ServerContext) -> Any: trainer = getattr(context, "trainer", None) diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index a7b60b78d..772ed5b0a 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -47,6 +47,20 @@ def __init__(self): self.register_buffer("bool_buffer", torch.tensor([True], dtype=torch.bool)) +class PeftLikeAdapterModel(torch.nn.Module): + """Model whose adapter payload keys omit PEFT's default adapter segment.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.base_model = torch.nn.Module() + self.base_model.model = torch.nn.Module() + self.base_model.model.linear = torch.nn.Module() + self.base_model.model.linear.lora_A = torch.nn.ModuleDict( + {"default": torch.nn.Linear(1, 1, bias=False)} + ) + + def _context(baseline=None, model=None): context = ServerContext() if baseline is not None: @@ -504,6 +518,33 @@ def test_parameters_policy_requires_trainer_model_context(temp_config): ) +def test_parameters_policy_maps_peft_adapter_payload_names(temp_config): + """PEFT payloads can omit adapter-name segments from trainable param names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = PeftLikeAdapterModel() + payload_name = "base_model.model.linear.lora_A.weight" + baseline = {payload_name: torch.zeros((1, 1))} + + server_delta = _aggregate( + strategy, + [_update(1)], + [{payload_name: torch.full((1, 1), 4.0)}], + baseline, + model, + ) + + assert torch.allclose(server_delta[payload_name], torch.full((1, 1), 2.0)) + assert set(strategy.momentum_state) == {payload_name} + assert torch.allclose( + strategy.momentum_state[payload_name], torch.full((1, 1), -4.0) + ) + + @pytest.mark.parametrize( ("kwargs", "match"), [ From a9d8f3b9bb518e82fe22efd30cebabb3f5ce0ee3 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:11:11 -0400 Subject: [PATCH 07/21] Avoided DiLoCo adapter alias overmatching. DT-413 re-review found that adapter-name aliasing could include a separate floating payload key when the exact trainable parameter key was also present. Make exact payload key matches take precedence over adapter-name removal, so alias candidates are only considered when the original trainable parameter name is absent from the payload. Added a negative collision regression to keep unrelated payload keys on the plain averaged-delta path. Validation: uv run pytest tests/servers/test_diloco_strategy.py -k "adapter_payload_names or alias_collisions" -q; uv run pytest tests/servers/test_diloco_strategy.py; uv run pytest tests/servers/test_fedavg_strategy.py; uv run ruff check . --select I. --- .../servers/strategies/aggregation/diloco.py | 3 ++ tests/servers/test_diloco_strategy.py | 45 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index 18162fd14..6504990cf 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -302,6 +302,9 @@ def _payload_name_candidates( payload_paths: set[str] | None, ) -> set[str]: candidates = {parameter_name} + if payload_paths is not None and parameter_name in payload_paths: + return candidates + parts = parameter_name.split(".") for index, part in enumerate(parts): if part not in adapter_names: diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index 772ed5b0a..196928af8 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -61,6 +61,17 @@ def __init__(self): ) +class AdapterAliasCollisionModel(torch.nn.Module): + """Model with a trainable parameter and separate payload key collision.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.foo = torch.nn.ModuleDict( + {"default": torch.nn.Linear(1, 1, bias=False)} + ) + + def _context(baseline=None, model=None): context = ServerContext() if baseline is not None: @@ -545,6 +556,40 @@ def test_parameters_policy_maps_peft_adapter_payload_names(temp_config): ) +def test_parameters_policy_does_not_overmatch_adapter_alias_collisions(temp_config): + """Alias support should not optimize unrelated colliding payload names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = AdapterAliasCollisionModel() + trainable_name = "foo.default.weight" + colliding_name = "foo.weight" + baseline = { + trainable_name: torch.zeros((1, 1)), + colliding_name: torch.zeros((1, 1)), + } + + server_delta = _aggregate( + strategy, + [_update(1)], + [ + { + trainable_name: torch.full((1, 1), 4.0), + colliding_name: torch.full((1, 1), 4.0), + } + ], + baseline, + model, + ) + + assert torch.allclose(server_delta[trainable_name], torch.full((1, 1), 2.0)) + assert torch.allclose(server_delta[colliding_name], torch.full((1, 1), 4.0)) + assert set(strategy.momentum_state) == {trainable_name} + + @pytest.mark.parametrize( ("kwargs", "match"), [ From 679c9f6ba7ccac84af03feea0a0deebebbc5b38b Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:13:31 -0400 Subject: [PATCH 08/21] Persisted in-process optimizer state for DiLoCo. DT-418 adds trainer.preserve_optimizer_state for the PyTorch ComposableTrainer in-process path so client-local AdamW and scheduler state survive communication rounds without entering client-server payloads. The trainer caches optimizer and scheduler state per logical client, restores it after creating the next round optimizer/scheduler, and discards cached state when optimizer type, scheduler type, parameter names, shapes, dtypes, or optimizer parameter ordering no longer match. Focused tests cover AdamW moment persistence, scheduler LR progress, logical-client isolation, payload locality, disabled behavior, optimizer changes, parameter-order changes, and shape/dtype/scheduler compatibility rejection. Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -q; uv run pytest tests/trainers/test_composable_trainer.py -q; uv run pytest tests/trainers -k optimizer_state-or-scheduler_state-or-composable -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I. The unignored trainer selector still hits the repo optional opacus collection dependency in tests/trainers/test_dp_data_loader_strategy.py. --- plato/trainers/composable.py | 147 +++++++++ .../test_composable_optimizer_state.py | 304 ++++++++++++++++++ 2 files changed, 451 insertions(+) create mode 100644 tests/trainers/test_composable_optimizer_state.py diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index f2bc2f0ca..8f16f0656 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -168,6 +168,7 @@ def __init__( self.current_epoch = 0 self.training_start_time = time.time() self.model_state_dict = None + self._preserved_optimizer_states: dict[int, dict[str, Any]] = {} def _require_model(self) -> nn.Module: """Return the underlying model, ensuring it is available.""" @@ -200,6 +201,143 @@ def _record_local_optimizer_step(self, local_steps_per_round: int | None) -> boo self.context.state["local_optimizer_steps"] = completed_steps return completed_steps >= local_steps_per_round + @staticmethod + def _preserve_optimizer_state(config: dict[str, Any]) -> bool: + """Return whether optimizer state should survive local train runs.""" + return bool(config.get("preserve_optimizer_state", False)) + + @staticmethod + def _parameter_signature(name: str | None, parameter: torch.Tensor): + """Build a compatibility signature for one model parameter.""" + return (name, tuple(parameter.shape), str(parameter.dtype)) + + @classmethod + def _model_parameter_signature(cls, model: nn.Module): + """Return parameter names, shapes, dtypes, and order for a model.""" + return tuple( + cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + ) + + @classmethod + def _optimizer_parameter_signature( + cls, model: nn.Module, optimizer: torch.optim.Optimizer + ): + """Return optimizer parameter group ordering with model metadata.""" + named_parameters = { + id(parameter): cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + } + + group_signatures = [] + for group in optimizer.param_groups: + group_signatures.append( + tuple( + named_parameters.get( + id(parameter), + cls._parameter_signature(None, parameter), + ) + for parameter in group.get("params", []) + ) + ) + + return tuple(group_signatures) + + @staticmethod + def _scheduler_type(scheduler: Any | None) -> type | None: + """Return the scheduler type used for compatibility checks.""" + if scheduler is None: + return None + return type(scheduler) + + def _preserved_state_is_compatible( + self, + payload: dict[str, Any], + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: Any | None, + ) -> bool: + """Return whether a cached optimizer bundle matches this train run.""" + if payload.get("optimizer_type") is not type(optimizer): + return False + + if payload.get("scheduler_type") is not self._scheduler_type(scheduler): + return False + + if payload.get("model_parameters") != self._model_parameter_signature(model): + return False + + if payload.get("optimizer_parameters") != self._optimizer_parameter_signature( + model, optimizer + ): + return False + + if not callable(getattr(optimizer, "load_state_dict", None)): + return False + + if payload.get("scheduler_state") is not None and not callable( + getattr(scheduler, "load_state_dict", None) + ): + return False + + return True + + def _restore_preserved_optimizer_state(self) -> None: + """Restore compatible optimizer and scheduler state for this client.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None or self.optimizer is None: + return + + model = self._require_model() + if not self._preserved_state_is_compatible( + payload, model, self.optimizer, self.lr_scheduler + ): + self._preserved_optimizer_states.pop(self.client_id, None) + return + + try: + scheduler_state = payload.get("scheduler_state") + if scheduler_state is not None: + self.lr_scheduler.load_state_dict(copy.deepcopy(scheduler_state)) + + self.optimizer.load_state_dict(copy.deepcopy(payload["optimizer_state"])) + except Exception as error: + logging.debug( + "[Client #%d] Discarding incompatible optimizer state: %s", + self.client_id, + error, + ) + self._preserved_optimizer_states.pop(self.client_id, None) + self.optimizer = self.optimizer_strategy.create_optimizer( + model, self.context + ) + self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( + self.optimizer, self.context + ) + + def _save_preserved_optimizer_state(self) -> None: + """Save optimizer and scheduler state locally for this logical client.""" + if self.optimizer is None: + return + + model = self._require_model() + scheduler_state = None + if self.lr_scheduler is not None: + state_dict_fn = getattr(self.lr_scheduler, "state_dict", None) + if callable(state_dict_fn): + scheduler_state = copy.deepcopy(state_dict_fn()) + + self._preserved_optimizer_states[self.client_id] = { + "optimizer_type": type(self.optimizer), + "optimizer_state": copy.deepcopy(self.optimizer.state_dict()), + "scheduler_type": self._scheduler_type(self.lr_scheduler), + "scheduler_state": scheduler_state, + "model_parameters": self._model_parameter_signature(model), + "optimizer_parameters": self._optimizer_parameter_signature( + model, self.optimizer + ), + } + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -420,6 +558,10 @@ def train_model(self, config, trainset, sampler, **kwargs): self.sampler = sampler self.context.config = config self.context.current_round = self.current_round + preserve_optimizer_state = self._preserve_optimizer_state(config) + if not preserve_optimizer_state: + self._preserved_optimizer_states.pop(self.client_id, None) + local_steps_per_round = self._local_steps_per_round(config) self.context.state["local_optimizer_steps"] = 0 if local_steps_per_round is None: @@ -513,6 +655,8 @@ def train_model(self, config, trainset, sampler, **kwargs): self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( self.optimizer, self.context ) + if preserve_optimizer_state: + self._restore_preserved_optimizer_state() # Move model to device model = self._require_model() @@ -744,6 +888,9 @@ def compute_loss(outputs, labels_inner): # Callbacks: train run end self.callback_handler.call_event("on_train_run_end", self, config) + if preserve_optimizer_state: + self._save_preserved_optimizer_state() + def train(self, trainset, sampler, **kwargs) -> float: """ The main training loop in a federated learning workload. diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py new file mode 100644 index 000000000..f37289625 --- /dev/null +++ b/tests/trainers/test_composable_optimizer_state.py @@ -0,0 +1,304 @@ +"""Tests for in-process optimizer state preservation in ComposableTrainer.""" + +import copy +from collections import OrderedDict + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset + +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import ( + AdamWOptimizerStrategy, + CrossEntropyLossStrategy, + DefaultTrainingStepStrategy, + SGDOptimizerStrategy, + StepLRSchedulerStrategy, +) + + +@pytest.fixture +def tiny_dataset(): + features = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [-1.0, 0.5], + ], + dtype=torch.float32, + ) + labels = torch.tensor([0, 1, 0, 1], dtype=torch.long) + return TensorDataset(features, labels) + + +@pytest.fixture +def one_step_config(): + return { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "optimizer-state-test", + } + + +class CapturingTrainingStep(DefaultTrainingStepStrategy): + """Record optimizer state before each local optimizer step.""" + + def __init__(self): + super().__init__() + self.pre_step_states = [] + self.pre_step_lrs = [] + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + optimizer_state = optimizer.state_dict() + self.pre_step_states.append(copy.deepcopy(optimizer_state["state"])) + self.pre_step_lrs.append( + [group["lr"] for group in optimizer_state["param_groups"]] + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + +def _linear_model(): + return nn.Sequential(OrderedDict([("linear", nn.Linear(2, 2))])) + + +def _two_layer_model(first_name="first", second_name="second"): + return nn.Sequential( + OrderedDict( + [ + (first_name, nn.Linear(2, 2, bias=False)), + (second_name, nn.Linear(2, 2, bias=False)), + ] + ) + ) + + +def _first_param_state(optimizer_state): + return next(iter(optimizer_state.values())) + + +def _state_step(param_state): + step = param_state["step"] + if isinstance(step, torch.Tensor): + return int(step.item()) + return int(step) + + +def test_adamw_moment_buffers_persist_between_rounds_for_same_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + trainer.set_client_id(7) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + round1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[1]) + saved_state = _first_param_state(round1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + assert torch.allclose(restored_state["exp_avg_sq"], saved_state["exp_avg_sq"]) + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 2 + + +def test_scheduler_state_and_lr_progress_persist_between_rounds( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + training_step_strategy=step_strategy, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_lrs == [[0.2], [0.1]] + assert trainer.lr_scheduler.last_epoch == 2 + assert trainer.optimizer.param_groups[0]["lr"] == pytest.approx(0.05) + + +def test_preserved_optimizer_state_is_local_to_logical_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + client1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + + trainer.set_client_id(2) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + assert step_strategy.pre_step_states[1] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[2]) + saved_state = _first_param_state(client1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + + +def test_preserved_state_stays_out_of_model_update_payload( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + update = trainer.obtain_model_update( + config, tiny_dataset, list(range(len(tiny_dataset))) + ) + + assert "optimizer_state" not in update + assert "scheduler_state" not in update + assert all(torch.is_tensor(value) for value in update.values()) + + +def test_preserved_state_invalidates_when_parameter_order_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=lambda: _two_layer_model("first", "second"), + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.model = _two_layer_model("second", "first") + trainer.context.model = trainer.model + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + + +def test_preserved_state_invalidates_when_optimizer_type_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.optimizer_strategy = SGDOptimizerStrategy(lr=0.1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + assert isinstance(trainer.optimizer, torch.optim.SGD) + + +def test_preserved_state_compatibility_rejects_shape_dtype_and_scheduler_changes( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(4) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + payload = copy.deepcopy(trainer._preserved_optimizer_states[4]) + + current_model = trainer.model + current_optimizer = trainer.optimizer_strategy.create_optimizer( + current_model, trainer.context + ) + changed_scheduler = StepLRSchedulerStrategy( + step_size=1, gamma=0.5 + ).create_scheduler(current_optimizer, trainer.context) + assert not trainer._preserved_state_is_compatible( + payload, current_model, current_optimizer, changed_scheduler + ) + + changed_shape_model = nn.Sequential( + OrderedDict([("linear", nn.Linear(2, 3))]) + ) + changed_shape_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_shape_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_shape_model, changed_shape_optimizer, None + ) + + changed_dtype_model = _linear_model().to(torch.float64) + changed_dtype_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_dtype_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_dtype_model, changed_dtype_optimizer, None + ) + + +@pytest.mark.parametrize("preserve_value", [None, False]) +def test_optimizer_state_is_not_restored_when_disabled_or_unset( + temp_config, tiny_dataset, one_step_config, preserve_value +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = dict(one_step_config) + if preserve_value is not None: + config["preserve_optimizer_state"] = preserve_value + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states == [{}, {}] + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 1 From 114f2a7f05461ec914d89d4fab360ecd9c672e55 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:21:42 -0400 Subject: [PATCH 09/21] Wired DiLoCo server selection. DT-414 adds server.type=diloco as a FedAvg-compatible server that injects DiLoCoAggregationStrategy while keeping algorithm.type=fedavg. The FedAvg delta path now filters non-weight reports before compute_weight_deltas(), so feature or metrics payloads cannot crash delta-only strategies before strategy eligibility handling. DiLoCo remains on aggregate_deltas and does not use inherited direct weight aggregation. Server-level tests cover registry/config selection, delta-path dispatch, inherited aggregate_weights avoidance, non-weight payload filtering, and existing FedAvg delta-strategy behavior. Validation: uv run pytest tests/servers/test_diloco_strategy.py -q; uv run pytest tests/servers/test_fedavg_strategy.py -q; uv run ruff check . --select I. --- plato/servers/diloco.py | 50 +++++++ plato/servers/fedavg.py | 27 +++- plato/servers/registry.py | 2 + tests/servers/test_diloco_strategy.py | 179 ++++++++++++++++++++++++++ 4 files changed, 255 insertions(+), 3 deletions(-) create mode 100644 plato/servers/diloco.py diff --git a/plato/servers/diloco.py b/plato/servers/diloco.py new file mode 100644 index 000000000..bedbb5f05 --- /dev/null +++ b/plato/servers/diloco.py @@ -0,0 +1,50 @@ +"""FedAvg-compatible server using DiLoCo aggregation.""" + +from plato.config import Config +from plato.servers import fedavg +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy + + +class Server(fedavg.Server): + """Federated learning server with server-side DiLoCo outer aggregation.""" + + def __init__( + self, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + aggregation_strategy=None, + client_selection_strategy=None, + ): + if aggregation_strategy is None: + aggregation_strategy = DiLoCoAggregationStrategy( + **self._aggregation_config() + ) + + super().__init__( + model=model, + datasource=datasource, + algorithm=algorithm, + trainer=trainer, + callbacks=callbacks, + aggregation_strategy=aggregation_strategy, + client_selection_strategy=client_selection_strategy, + ) + + @staticmethod + def _aggregation_config() -> dict: + """Read optional DiLoCo aggregation settings from [server.diloco].""" + config = getattr(Config().server, "diloco", None) + if config is None: + return {} + + keys = ( + "outer_optimizer", + "outer_learning_rate", + "outer_momentum", + "aggregation_weighting", + "apply_outer_optimizer_to", + ) + return {key: getattr(config, key) for key in keys if hasattr(config, key)} diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 5f862fc35..e357adc31 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -222,13 +222,20 @@ async def _process_reports(self): # Use delta aggregation (default path) # Computes the weight deltas by comparing the weights received with # the current global model weights - deltas_received = algorithm.compute_weight_deltas( - baseline_weights, weights_received + delta_updates, delta_weights_received = ( + self._weight_updates_and_payloads(self.updates, weights_received) + ) + deltas_received = ( + algorithm.compute_weight_deltas( + baseline_weights, delta_weights_received + ) + if delta_weights_received + else [] ) # Runs a framework-agnostic server aggregation algorithm, such as # the federated averaging algorithm logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid()) - deltas = await self.aggregate_deltas(self.updates, deltas_received) + deltas = await self.aggregate_deltas(delta_updates, deltas_received) # Updates the existing model weights from the provided deltas updated_weights = algorithm.update_weights(deltas) # Loads the new model weights @@ -299,6 +306,20 @@ def _should_prefer_weight_aggregation(self) -> bool: and aggregate_deltas_impl is not FedAvgAggregationStrategy.aggregate_deltas ) + @staticmethod + def _weight_updates_and_payloads(updates, weights_received): + """Return update/payload pairs whose reports contain model weights.""" + delta_updates = [] + delta_weights_received = [] + + for update, weights in zip(updates, weights_received): + if getattr(update.report, "type", "weights") != "weights": + continue + delta_updates.append(update) + delta_weights_received.append(weights) + + return delta_updates, delta_weights_received + def clients_processed(self) -> None: """Additional work to be performed after client reports have been processed.""" diff --git a/plato/servers/registry.py b/plato/servers/registry.py index 2211b8972..b26465a8c 100644 --- a/plato/servers/registry.py +++ b/plato/servers/registry.py @@ -10,6 +10,7 @@ from plato.config import Config from plato.servers import ( + diloco, fedavg, fedavg_cs, fedavg_gan, @@ -30,6 +31,7 @@ registered_servers = { "fedavg": fedavg.Server, "fedavg_lora": fedavg.Server, + "diloco": diloco.Server, "fedavg_cross_silo": fedavg_cs.Server, "fedavg_gan": fedavg_gan.Server, "fedavg_personalized": fedavg_personalized.Server, diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index 196928af8..f35467344 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -35,6 +35,60 @@ def compute_weight_deltas(self, baseline_weights, weights_list): ] +class ServerAlgorithm(DummyAlgorithm): + """Algorithm stub for exercising FedAvg-compatible server dispatch.""" + + def __init__(self, baseline): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + self.delta_payloads = None + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.current.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + self.delta_payloads = weights_list + return super().compute_weight_deltas(baseline_weights, weights_list) + + def update_weights(self, deltas): + self.current = { + name: self.current[name] + deltas[name] for name in self.current + } + return self.extract_weights() + + def load_weights(self, weights): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in weights.items() + } + + +class RecordingDiLoCoStrategy(DiLoCoAggregationStrategy): + """DiLoCo strategy recording server dispatch calls.""" + + def __init__(self): + super().__init__( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + self.delta_calls = 0 + self.last_updates = None + self.last_deltas = None + + async def aggregate_deltas(self, updates, deltas_received, context): + self.delta_calls += 1 + self.last_updates = updates + self.last_deltas = deltas_received + return await super().aggregate_deltas(updates, deltas_received, context) + + class MixedStateModel(torch.nn.Module): """Model exposing trainable, frozen, floating-buffer, and integer state.""" @@ -87,12 +141,137 @@ def _update(num_samples, report_type="weights"): ) +def _server_update(payload, num_samples=1, report_type="weights"): + update = _update(num_samples, report_type) + update.client_id = len(str(payload)) + update.report.accuracy = 0.5 + update.report.processing_time = 0.1 + update.report.comm_time = 0.1 + update.report.training_time = 0.1 + update.payload = payload + return update + + def _aggregate(strategy, updates, deltas, baseline=None, model=None): return asyncio.run( strategy.aggregate_deltas(updates, deltas, _context(baseline, model)) ) +def test_diloco_server_type_uses_fedavg_algorithm_and_strategy(temp_config): + """server.type=diloco should select a FedAvg-compatible DiLoCo server.""" + from plato.algorithms import registry as algorithms_registry + from plato.config import Config + from plato.servers import diloco as diloco_server + from plato.servers import fedavg + from plato.servers import registry as servers_registry + + Config().server.type = "diloco" + Config().algorithm.type = "fedavg" + Config().server.diloco = SimpleNamespace( + outer_optimizer="sgd", + outer_learning_rate=0.25, + outer_momentum=0.1, + aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + server = servers_registry.get() + + assert isinstance(server, diloco_server.Server) + assert isinstance(server, fedavg.Server) + assert isinstance(server.aggregation_strategy, DiLoCoAggregationStrategy) + assert server.aggregation_strategy.outer_optimizer == "sgd" + assert server.aggregation_strategy.outer_learning_rate == 0.25 + assert server.aggregation_strategy.outer_momentum == 0.1 + assert server.aggregation_strategy.aggregation_weighting == "num_samples" + assert server.aggregation_strategy.apply_outer_optimizer_to == "all_floating" + assert Config().algorithm.type == "fedavg" + assert "diloco" not in algorithms_registry.registered_algorithms + + +def test_diloco_server_process_reports_uses_delta_aggregation(temp_config): + """DiLoCo server processing should reach the delta aggregation path.""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [ + _server_update({"w": torch.tensor([2.0])}), + _server_update({"w": torch.tensor([4.0])}), + ] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert strategy.last_updates == server.updates + assert len(strategy.last_deltas) == 2 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([3.0])) + + +def test_diloco_server_does_not_use_inherited_weight_aggregation(temp_config): + """DiLoCo must not bypass delta aggregation via inherited FedAvg weights.""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + + async def fail_if_called(*_args, **_kwargs): + raise AssertionError("Inherited aggregate_weights() must not be called.") + + strategy.aggregate_weights = fail_if_called + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [_server_update({"w": torch.tensor([2.0])})] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + +def test_diloco_server_filters_non_weight_reports_before_delta_computation( + temp_config, +): + """Non-weight payloads should not reach compute_weight_deltas().""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + weight_payload = {"w": torch.tensor([2.0])} + server.updates = [ + _server_update("feature payload", report_type="features"), + _server_update({"metrics": 1.0}, report_type="metrics"), + _server_update(weight_payload), + ] + + asyncio.run(server._process_reports()) + + assert server.algorithm.delta_payloads == [weight_payload] + assert strategy.last_updates == [server.updates[2]] + assert len(strategy.last_deltas) == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): """Outer SGD with lr=1 should match uniform averaging under uniform mode.""" strategy = DiLoCoAggregationStrategy( From da0f3413af89af9c31716feb98f7bbaa14010585 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:29:27 -0400 Subject: [PATCH 10/21] Persisted optimizer state across train subprocesses. DT-420 extends trainer.preserve_optimizer_state to ComposableTrainer subprocess training by using a local optimizer-state sidecar under the configured model path. Child training loads any preserved sidecar before train_model(), saves updated optimizer and scheduler state after training, and the parent reloads the sidecar after the trained model is loaded. Missing, unreadable, invalid, or incompatible state falls back to fresh optimizer/scheduler state with explicit logging. Tests cover parent reload, optimizer state persistence across two subprocess rounds, scheduler progress, invalid sidecar reset, disabled behavior, and payload non-leakage. State remains local and is not added to network payloads. Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -k "subprocess and (optimizer_state or scheduler_state)" -q; uv run pytest tests/trainers/test_composable_optimizer_state.py -q; uv run pytest tests/trainers/test_composable_trainer.py -q; uv run pytest tests/trainers -k "subprocess and (optimizer_state or scheduler_state)" -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I; git diff --check. --- plato/trainers/composable.py | 105 ++++++++++- .../test_composable_optimizer_state.py | 163 ++++++++++++++++++ 2 files changed, 266 insertions(+), 2 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 8f16f0656..fc293302d 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -292,6 +292,11 @@ def _restore_preserved_optimizer_state(self) -> None: if not self._preserved_state_is_compatible( payload, model, self.optimizer, self.lr_scheduler ): + logging.info( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + ) self._preserved_optimizer_states.pop(self.client_id, None) return @@ -302,8 +307,9 @@ def _restore_preserved_optimizer_state(self) -> None: self.optimizer.load_state_dict(copy.deepcopy(payload["optimizer_state"])) except Exception as error: - logging.debug( - "[Client #%d] Discarding incompatible optimizer state: %s", + logging.warning( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state: %s", self.client_id, error, ) @@ -338,6 +344,82 @@ def _save_preserved_optimizer_state(self) -> None: ), } + def _optimizer_state_filename(self, run_id: str) -> str: + """Return the local optimizer-state handoff filename.""" + model_name = Config().trainer.model_name + return f"{model_name}_{self.client_id}_{run_id}.optim.pkl" + + def _optimizer_state_path(self, filename: str) -> str: + """Return the local optimizer-state handoff path.""" + return os.path.join(Config().params["model_path"], filename) + + def _save_preserved_optimizer_state_file(self, filename: str) -> None: + """Persist preserved optimizer state for subprocess handoff.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None: + return + + model_path = Config().params["model_path"] + os.makedirs(model_path, exist_ok=True) + state_path = self._optimizer_state_path(filename) + tmp_path = f"{state_path}.{os.getpid()}.tmp" + + try: + with open(tmp_path, "wb") as state_file: + pickle.dump(copy.deepcopy(payload), state_file) + os.replace(tmp_path, state_path) + except Exception as error: + if os.path.exists(tmp_path): + os.remove(tmp_path) + logging.warning( + "[Client #%d] Failed to persist optimizer state to %s: %s", + self.client_id, + state_path, + error, + ) + + def _load_preserved_optimizer_state_file( + self, filename: str, *, clear_on_missing: bool = False + ) -> None: + """Load preserved optimizer state from a subprocess handoff file.""" + state_path = self._optimizer_state_path(filename) + if not os.path.exists(state_path): + if clear_on_missing: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.info( + "[Client #%d] No persisted optimizer state found at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return + + try: + with open(state_path, "rb") as state_file: + payload = pickle.load(state_file) + except Exception as error: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding unreadable optimizer state at %s; " + "starting with fresh optimizer and scheduler state: %s", + self.client_id, + state_path, + error, + ) + return + + if not isinstance(payload, dict): + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding invalid optimizer state at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return + + self._preserved_optimizer_states[self.client_id] = payload + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -545,8 +627,16 @@ def simulate_sleep_time(self): def train_process(self, config, trainset, sampler, **kwargs): """The training process in a federated learning workload.""" + preserve_optimizer_state = self._preserve_optimizer_state(config) + if preserve_optimizer_state: + optimizer_state_filename = self._optimizer_state_filename(config["run_id"]) + self._load_preserved_optimizer_state_file(optimizer_state_filename) + self.train_model(config, trainset, sampler, **kwargs) + if preserve_optimizer_state: + self._save_preserved_optimizer_state_file(optimizer_state_filename) + model_name = Config().trainer.model_name filename = f"{model_name}_{self.client_id}_{config['run_id']}.safetensors" self.save_model(filename) @@ -911,6 +1001,12 @@ def train(self, trainset, sampler, **kwargs) -> float: if "max_concurrency" in config: tic = time.perf_counter() + preserve_optimizer_state = self._preserve_optimizer_state(config) + optimizer_state_filename = None + if preserve_optimizer_state: + optimizer_state_filename = self._optimizer_state_filename( + config["run_id"] + ) if mp.get_start_method(allow_none=True) != "spawn": mp.set_start_method("spawn", force=True) @@ -963,6 +1059,11 @@ def train(self, trainset, sampler, **kwargs) -> float: f"Training on client {self.client_id} failed." ) from error + if optimizer_state_filename is not None: + self._load_preserved_optimizer_state_file( + optimizer_state_filename, clear_on_missing=True + ) + toc = time.perf_counter() self.pause_training() else: diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index f37289625..59f57fd2e 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -1,13 +1,18 @@ """Tests for in-process optimizer state preservation in ComposableTrainer.""" import copy +import os +import pickle +import sys from collections import OrderedDict +from pathlib import Path import pytest import torch import torch.nn as nn from torch.utils.data import TensorDataset +from plato.config import Config from plato.trainers.composable import ComposableTrainer from plato.trainers.strategies import ( AdamWOptimizerStrategy, @@ -101,6 +106,38 @@ def _state_step(param_state): return int(step) +def _configure_subprocess_training( + monkeypatch, + tmp_path, + *, + preserve_optimizer_state, +): + """Configure parent and spawned child processes to share local artifacts.""" + model_path = Path(tmp_path) / "models" / "pretrained" + model_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(Path(tmp_path) / "checkpoints") + Config.params["run_id"] = "subprocess-optimizer-state" + os.makedirs(Config.params["checkpoint_path"], exist_ok=True) + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + Config().trainer = Config().trainer._replace( + max_concurrency=1, + preserve_optimizer_state=preserve_optimizer_state, + batch_size=4, + epochs=1, + ) + + +def _cached_optimizer_step(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return _state_step(_first_param_state(payload["optimizer_state"]["state"])) + + +def _cached_scheduler_last_epoch(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return payload["scheduler_state"]["last_epoch"] + + def test_adamw_moment_buffers_persist_between_rounds_for_same_client( temp_config, tiny_dataset, one_step_config ): @@ -149,6 +186,132 @@ def test_scheduler_state_and_lr_progress_persist_between_rounds( assert trainer.optimizer.param_groups[0]["lr"] == pytest.approx(0.05) +def test_subprocess_optimizer_state_parent_reloads_after_child( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer.client_id in trainer._preserved_optimizer_states + assert _cached_optimizer_step(trainer) == 1 + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + assert state_path.exists() + assert "optimizer_state" not in trainer.obtain_model_update( + { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "payload-check", + "preserve_optimizer_state": True, + }, + tiny_dataset, + list(range(len(tiny_dataset))), + ) + + +def test_subprocess_optimizer_state_persists_across_rounds_for_same_client( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 2 + + +def test_subprocess_scheduler_state_persists_across_rounds( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert _cached_scheduler_last_epoch(trainer) == 2 + assert payload["optimizer_state"]["param_groups"][0]["lr"] == pytest.approx(0.05) + + +def test_subprocess_invalid_optimizer_state_resets_safely( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + with open(state_path, "wb") as state_file: + pickle.dump({"optimizer_type": torch.optim.SGD}, state_file) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert payload["optimizer_type"] is torch.optim.AdamW + assert _cached_optimizer_step(trainer) == 1 + + +def test_subprocess_optimizer_state_is_not_persisted_when_disabled( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=False + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer._preserved_optimizer_states == {} + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + assert not state_path.exists() + + def test_preserved_optimizer_state_is_local_to_logical_client( temp_config, tiny_dataset, one_step_config ): From c7307320096a3f6881a0f185a7410303931aa9aa Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:39:42 -0400 Subject: [PATCH 11/21] Hardened subprocess optimizer state handoff. DT-421 review found that a missing optimizer sidecar could leave inherited parent cache active in the child, and that parent reload could confuse stale input with current child output. The child now clears inherited cache when the input sidecar is missing. Subprocess training writes to a unique child output sidecar, the parent loads that output, promotes it to the stable input sidecar for the next round, and removes stale stable state if child output is missing or invalid. Added regressions for missing input sidecars clearing inherited cache and missing child output removing stale input sidecars. Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -k "subprocess or sidecar" -q; uv run pytest tests/trainers/test_composable_optimizer_state.py -q; uv run pytest tests/trainers/test_composable_trainer.py -q; uv run pytest tests/trainers -k "subprocess and (optimizer_state or scheduler_state)" -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I; git diff --check. --- plato/trainers/composable.py | 82 ++++++++++++++++--- .../test_composable_optimizer_state.py | 76 +++++++++++++++++ 2 files changed, 146 insertions(+), 12 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index fc293302d..78ceb07bf 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -349,15 +349,21 @@ def _optimizer_state_filename(self, run_id: str) -> str: model_name = Config().trainer.model_name return f"{model_name}_{self.client_id}_{run_id}.optim.pkl" + def _optimizer_state_output_filename(self, run_id: str) -> str: + """Return a unique subprocess optimizer-state output filename.""" + model_name = Config().trainer.model_name + token = time.time_ns() + return f"{model_name}_{self.client_id}_{run_id}_{os.getpid()}_{token}.optim.pkl" + def _optimizer_state_path(self, filename: str) -> str: """Return the local optimizer-state handoff path.""" return os.path.join(Config().params["model_path"], filename) - def _save_preserved_optimizer_state_file(self, filename: str) -> None: + def _save_preserved_optimizer_state_file(self, filename: str) -> bool: """Persist preserved optimizer state for subprocess handoff.""" payload = self._preserved_optimizer_states.get(self.client_id) if payload is None: - return + return False model_path = Config().params["model_path"] os.makedirs(model_path, exist_ok=True) @@ -368,6 +374,7 @@ def _save_preserved_optimizer_state_file(self, filename: str) -> None: with open(tmp_path, "wb") as state_file: pickle.dump(copy.deepcopy(payload), state_file) os.replace(tmp_path, state_path) + return True except Exception as error: if os.path.exists(tmp_path): os.remove(tmp_path) @@ -377,10 +384,11 @@ def _save_preserved_optimizer_state_file(self, filename: str) -> None: state_path, error, ) + return False def _load_preserved_optimizer_state_file( self, filename: str, *, clear_on_missing: bool = False - ) -> None: + ) -> bool: """Load preserved optimizer state from a subprocess handoff file.""" state_path = self._optimizer_state_path(filename) if not os.path.exists(state_path): @@ -392,7 +400,7 @@ def _load_preserved_optimizer_state_file( self.client_id, state_path, ) - return + return False try: with open(state_path, "rb") as state_file: @@ -406,7 +414,7 @@ def _load_preserved_optimizer_state_file( state_path, error, ) - return + return False if not isinstance(payload, dict): self._preserved_optimizer_states.pop(self.client_id, None) @@ -416,9 +424,38 @@ def _load_preserved_optimizer_state_file( self.client_id, state_path, ) - return + return False self._preserved_optimizer_states[self.client_id] = payload + return True + + def _remove_preserved_optimizer_state_file(self, filename: str) -> None: + """Remove a local optimizer-state sidecar if it exists.""" + state_path = self._optimizer_state_path(filename) + try: + os.remove(state_path) + except FileNotFoundError: + return + except OSError as error: + logging.warning( + "[Client #%d] Failed to remove optimizer state at %s: %s", + self.client_id, + state_path, + error, + ) + + def _finish_subprocess_optimizer_state( + self, input_filename: str, output_filename: str + ) -> None: + """Load the child output sidecar and promote it for the next round.""" + loaded = self._load_preserved_optimizer_state_file( + output_filename, clear_on_missing=True + ) + if loaded: + self._save_preserved_optimizer_state_file(input_filename) + self._remove_preserved_optimizer_state_file(output_filename) + else: + self._remove_preserved_optimizer_state_file(input_filename) @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: @@ -629,13 +666,22 @@ def train_process(self, config, trainset, sampler, **kwargs): """The training process in a federated learning workload.""" preserve_optimizer_state = self._preserve_optimizer_state(config) if preserve_optimizer_state: - optimizer_state_filename = self._optimizer_state_filename(config["run_id"]) - self._load_preserved_optimizer_state_file(optimizer_state_filename) + optimizer_state_filename = config.get( + "_optimizer_state_input_filename", + self._optimizer_state_filename(config["run_id"]), + ) + optimizer_state_output_filename = config.get( + "_optimizer_state_output_filename", + optimizer_state_filename, + ) + self._load_preserved_optimizer_state_file( + optimizer_state_filename, clear_on_missing=True + ) self.train_model(config, trainset, sampler, **kwargs) if preserve_optimizer_state: - self._save_preserved_optimizer_state_file(optimizer_state_filename) + self._save_preserved_optimizer_state_file(optimizer_state_output_filename) model_name = Config().trainer.model_name filename = f"{model_name}_{self.client_id}_{config['run_id']}.safetensors" @@ -1003,10 +1049,19 @@ def train(self, trainset, sampler, **kwargs) -> float: tic = time.perf_counter() preserve_optimizer_state = self._preserve_optimizer_state(config) optimizer_state_filename = None + optimizer_state_output_filename = None if preserve_optimizer_state: optimizer_state_filename = self._optimizer_state_filename( config["run_id"] ) + optimizer_state_output_filename = self._optimizer_state_output_filename( + config["run_id"] + ) + config = { + **config, + "_optimizer_state_input_filename": optimizer_state_filename, + "_optimizer_state_output_filename": optimizer_state_output_filename, + } if mp.get_start_method(allow_none=True) != "spawn": mp.set_start_method("spawn", force=True) @@ -1059,9 +1114,12 @@ def train(self, trainset, sampler, **kwargs) -> float: f"Training on client {self.client_id} failed." ) from error - if optimizer_state_filename is not None: - self._load_preserved_optimizer_state_file( - optimizer_state_filename, clear_on_missing=True + if ( + optimizer_state_filename is not None + and optimizer_state_output_filename is not None + ): + self._finish_subprocess_optimizer_state( + optimizer_state_filename, optimizer_state_output_filename ) toc = time.perf_counter() diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index 59f57fd2e..9a8cab29e 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -262,6 +262,82 @@ def test_subprocess_scheduler_state_persists_across_rounds( assert payload["optimizer_state"]["param_groups"][0]["lr"] == pytest.approx(0.05) +def test_subprocess_missing_sidecar_clears_inherited_parent_cache( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + source_trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + source_trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + assert _cached_optimizer_step(source_trainer) == 1 + + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + trainer._preserved_optimizer_states[7] = copy.deepcopy( + source_trainer._preserved_optimizer_states[7] + ) + + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + state_path.unlink(missing_ok=True) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 1 + + +def test_missing_subprocess_output_removes_stale_input_sidecar( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + input_filename = trainer._optimizer_state_filename(Config.params["run_id"]) + missing_output_filename = trainer._optimizer_state_output_filename( + Config.params["run_id"] + ) + assert trainer._save_preserved_optimizer_state_file(input_filename) + input_path = Path(Config.params["model_path"]) / input_filename + assert input_path.exists() + + trainer._finish_subprocess_optimizer_state( + input_filename, missing_output_filename + ) + + assert trainer.client_id not in trainer._preserved_optimizer_states + assert not input_path.exists() + + def test_subprocess_invalid_optimizer_state_resets_safely( temp_config, monkeypatch, tmp_path, tiny_dataset ): From 33129d083443f932bbf9a4b6fc4e9c54a890d375 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:56:47 -0400 Subject: [PATCH 12/21] Added DiLoCo payload safety coverage. DT-422 adds regression tests proving client-local optimizer and scheduler state remains local when trainer.preserve_optimizer_state is enabled. Client tests now cover the FedAvg/DiLoCo-compatible in-process path and subprocess sidecar path, asserting outbound payloads contain exactly model state tensors and reject optimizer_state, scheduler_state, global_step, local metadata, and sidecar filename keys. Trainer tests also verify model-update payloads stay model-only while optimizer and scheduler state are persisted locally. Validation: uv run pytest tests/clients -k "payload or simple" -q; uv run pytest tests/trainers -k "optimizer_state or scheduler_state" -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I; git diff --check. --- tests/clients/test_simple_client.py | 155 ++++++++++++++++-- .../test_composable_optimizer_state.py | 29 +++- 2 files changed, 171 insertions(+), 13 deletions(-) diff --git a/tests/clients/test_simple_client.py b/tests/clients/test_simple_client.py index 43ffbb907..11b1bd8e0 100644 --- a/tests/clients/test_simple_client.py +++ b/tests/clients/test_simple_client.py @@ -1,7 +1,10 @@ """End-to-end smoke tests for the strategy-based client runtime.""" import asyncio +import pickle +import sys from dataclasses import dataclass +from pathlib import Path import torch from torch.utils.data import Dataset @@ -10,6 +13,20 @@ from plato.clients import simple from plato.config import Config from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import AdamWOptimizerStrategy, StepLRSchedulerStrategy +from tests.test_utils.fakes import NoOpCommunicationStrategy + +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} class ToyDataset(Dataset): @@ -48,37 +65,155 @@ def get_test_set(self): return self._test -def _build_client(): +def _build_client(trainer=ComposableTrainer): """Instantiate a client wired with custom model, datasource, and trainer.""" return simple.Client( model=torch.nn.Linear(4, 2), datasource=ToyDatasource, - trainer=ComposableTrainer, + trainer=trainer, algorithm=lambda trainer: fedavg.Algorithm(trainer), ) -def test_simple_client_trains_with_default_strategies(temp_config): - """A simple client should complete one training round using the strategy stack.""" - Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) +def _build_stateful_trainer(model=None, callbacks=None): + """Build a trainer whose local optimizer and scheduler state is non-empty.""" + return ComposableTrainer( + model=model, + callbacks=callbacks, + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) - client = _build_client() - # Assign identifiers expected by the client runtime. +def _configure_one_round_client(client): + """Prepare a client for a deterministic single training round.""" client.client_id = 1 client._context.client_id = 1 client.current_round = 1 client._context.current_round = 1 - # Prepare data and runtime components. client._load_data() client.configure() client._allocate_data() + +def _disable_payload_processors(client): + """Keep the test focused on decoded client-server model payload contents.""" + client.inbound_processor = None + client.outbound_processor = None + client._context.inbound_processor = None + client._context.outbound_processor = None + + +def _assert_model_weight_payload(payload, model): + """Assert that an outbound payload contains exactly model state tensors.""" + model_state = model.state_dict() + + assert isinstance(payload, dict) + assert set(payload) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) + assert all(torch.is_tensor(value) for value in payload.values()) + + for name, expected in model_state.items(): + assert torch.equal(payload[name], expected) + + +def _assert_preserved_state_is_local(trainer, client_id): + """Assert optimizer and scheduler persistence exists only in trainer state.""" + state = trainer._preserved_optimizer_states[client_id] + + assert state["optimizer_state"]["state"] + assert state["scheduler_state"] is not None + assert state["scheduler_state"]["last_epoch"] >= 1 + assert state["scheduler_state"]["_step_count"] >= 2 + + +def test_simple_client_trains_with_default_strategies(temp_config): + """A simple client should complete one training round using the strategy stack.""" + Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) + + client = _build_client() + + _configure_one_round_client(client) + report, payload = asyncio.run(client._train()) assert report.client_id == 1 # With partition_size=4 each client receives four samples. assert report.num_samples == 4 - assert isinstance(payload, dict) - assert all(isinstance(value, torch.Tensor) for value in payload.values()) + _assert_model_weight_payload(payload, client.trainer.model) + + +def test_simple_client_payload_excludes_local_state_when_persistence_enabled( + temp_config, +): + """FedAvg/DiLoCo client payloads stay model-only with local persistence.""" + Config.params["run_id"] = "client-payload-in-process" + Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + _assert_preserved_state_is_local(client.trainer, client.client_id) + _assert_model_weight_payload(sent_payload, client.trainer.model) + + +def test_simple_client_subprocess_payload_excludes_local_state_sidecar( + temp_config, monkeypatch, tmp_path +): + """Subprocess persistence uses a sidecar without changing server payloads.""" + model_path = Path(tmp_path) / "models" / "pretrained" + checkpoint_path = Path(tmp_path) / "checkpoints" + model_path.mkdir(parents=True, exist_ok=True) + checkpoint_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(checkpoint_path) + Config.params["run_id"] = "client-payload-subprocess" + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + max_concurrency=1, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + state_path = ( + Path(Config.params["model_path"]) + / client.trainer._optimizer_state_filename(Config.params["run_id"]) + ) + with state_path.open("rb") as state_file: + sidecar_state = pickle.load(state_file) + + _assert_preserved_state_is_local(client.trainer, client.client_id) + assert sidecar_state["optimizer_state"]["state"] + assert sidecar_state["scheduler_state"] is not None + _assert_model_weight_payload(sent_payload, client.trainer.model) diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index 9a8cab29e..3620c6248 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -22,6 +22,18 @@ StepLRSchedulerStrategy, ) +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} + @pytest.fixture def tiny_dataset(): @@ -138,6 +150,14 @@ def _cached_scheduler_last_epoch(trainer): return payload["scheduler_state"]["last_epoch"] +def _assert_model_update_contains_only_model_weights(update, model): + model_state = model.state_dict() + + assert set(update) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(update) + assert all(torch.is_tensor(value) for value in update.values()) + + def test_adamw_moment_buffers_persist_between_rounds_for_same_client( temp_config, tiny_dataset, one_step_config ): @@ -424,16 +444,19 @@ def test_preserved_state_stays_out_of_model_update_payload( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), ) + trainer.set_client_id(5) config = {**one_step_config, "preserve_optimizer_state": True} update = trainer.obtain_model_update( config, tiny_dataset, list(range(len(tiny_dataset))) ) + preserved_state = trainer._preserved_optimizer_states[trainer.client_id] - assert "optimizer_state" not in update - assert "scheduler_state" not in update - assert all(torch.is_tensor(value) for value in update.values()) + assert preserved_state["optimizer_state"]["state"] + assert preserved_state["scheduler_state"]["last_epoch"] >= 1 + _assert_model_update_contains_only_model_weights(update, trainer.model) def test_preserved_state_invalidates_when_parameter_order_changes( From 21bf980c0c86c5bc732f3aa60613d5dd572c98a5 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:06:55 -0400 Subject: [PATCH 13/21] Added round-aware local-step sampling. DT-428 prevents exact local-step training from replaying the same deterministic sampler prefix when H is smaller than one epoch and the train loader is recreated each round. The data-loader strategies now materialize supported sampler streams only when trainer.local_steps_per_round is active, rotate the stream by the deterministic round offset, and leave epoch-based training unchanged when local-step limits are unset. Unsupported non-materializable sampler objects log a clear warning and fall back unchanged. Added focused red/green coverage showing two short local-step rounds for the same client consume different prefixes while repeated runs with the same round sequence remain deterministic. Validation: uv run pytest tests/trainers -k "local_steps or data_loader or sampler" -q; uv run pytest tests/samplers -q; uv run ruff check . --select I; git diff --check. --- plato/trainers/strategies/data_loader.py | 81 +++++++++++++++++++++ tests/trainers/test_composable_trainer.py | 88 ++++++++++++++++++++++- 2 files changed, 168 insertions(+), 1 deletion(-) diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index 91e5a0482..f65bbf3f3 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -20,6 +20,19 @@ AdjustFn = Callable[[TrainingContext], int] +class _FixedOrderSampler(torch.utils.data.Sampler): + """Sampler that yields precomputed dataset indices in order.""" + + def __init__(self, indices: list[int]): + self._indices = indices + + def __iter__(self): + return iter(self._indices) + + def __len__(self): + return len(self._indices) + + def _context_uses_cuda(context: TrainingContext) -> bool: """Return True if the training context targets a CUDA device.""" device = getattr(context, "device", None) @@ -40,6 +53,54 @@ def _resolve_pin_memory(setting: bool | None, context: TrainingContext) -> bool: return _context_uses_cuda(context) +def _local_step_stream_start( + context: TrainingContext, samples_per_round: int, stream_length: int +) -> int: + """Return the deterministic stream offset for this local-step round.""" + current_round = int(getattr(context, "current_round", 0) or 0) + if current_round > 0: + return ((current_round - 1) * samples_per_round) % stream_length + + offset = int(context.state.get("_local_step_sampler_stream_offset", 0)) + context.state["_local_step_sampler_stream_offset"] = offset + samples_per_round + return offset % stream_length + + +def _apply_local_step_sampling_stream( + sampler_obj, batch_size: int, context: TrainingContext +): + """Advance deterministic samplers across short local-step rounds.""" + local_steps_per_round = context.state.get("local_steps_per_round") + if local_steps_per_round is None or sampler_obj is None: + return sampler_obj + + samples_per_round = int(local_steps_per_round) * int(batch_size) + if samples_per_round <= 0: + return sampler_obj + + try: + indices = list(iter(sampler_obj)) + except TypeError: + logging.warning( + "Sampler %s cannot be materialized for round-aware local-step " + "sampling; using it unchanged. Consecutive short local rounds may " + "replay the same sampler prefix.", + type(sampler_obj), + ) + return sampler_obj + + if len(indices) == 0: + return sampler_obj + + start = _local_step_stream_start(context, samples_per_round, len(indices)) + if start == 0: + ordered_indices = indices + else: + ordered_indices = indices[start:] + indices[:start] + + return _FixedOrderSampler(ordered_indices) + + class DefaultDataLoaderStrategy(DataLoaderStrategy): """ Default data loader strategy. @@ -100,6 +161,10 @@ def create_train_loader( sampler_obj = None shuffle = self.shuffle + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + if sampler is None and not shuffle: logging.warning( "Data loader strategy received no sampler; falling back to SequentialSampler." @@ -174,6 +239,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -239,6 +308,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -320,6 +393,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, actual_batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=actual_batch_size, @@ -383,6 +460,10 @@ def create_train_loader( sampler_obj = None shuffle = True + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index d5daeb4b4..441ab7895 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -8,7 +8,7 @@ import pytest import torch import torch.nn as nn -from torch.utils.data import TensorDataset +from torch.utils.data import SubsetRandomSampler, TensorDataset from plato.callbacks.trainer import TrainerCallback from plato.config import Config @@ -184,6 +184,19 @@ def test_multiple_epochs(self, simple_model, simple_dataset): class TestComposableTrainerLocalSteps: """Test local optimizer-step limits for DiLoCo-style local work.""" + class DeterministicPlatoSampler: + def __init__(self, indices, seed=47): + self.indices = list(indices) + self.seed = seed + + def get(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + return SubsetRandomSampler(self.indices, generator=generator) + + def num_samples(self): + return len(self.indices) + class CountingCallback(TrainerCallback): def __init__(self): self.train_run_end_called = False @@ -230,6 +243,33 @@ def training_step( context=context, ) + class RecordingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.samples_by_round = {} + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + sample_ids = examples[:, 0].detach().cpu().int().tolist() + self.samples_by_round.setdefault(context.current_round, []).extend( + sample_ids + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + class DelayedOptimizerStepStrategy(TrainingStepStrategy): def __init__(self, accumulation_steps=2, finalize_steps=True): self.accumulation_steps = accumulation_steps @@ -420,6 +460,52 @@ def test_epoch_behavior_is_unchanged_when_local_steps_unset( assert trainer.current_epoch == 2 assert len(trainer.run_history.get_metric_values("train_loss")) == 2 + def test_local_steps_do_not_replay_same_deterministic_sampler_prefix( + self, simple_model, simple_config + ): + dataset_size = 10 + features = torch.arange(dataset_size, dtype=torch.float32).view(-1, 1) + features = features.repeat(1, 10) + labels = torch.arange(dataset_size) % 2 + dataset = TensorDataset(features, labels) + config = { + **simple_config, + "batch_size": 1, + "epochs": 1, + "local_steps_per_round": 3, + } + sampler = self.DeterministicPlatoSampler(range(dataset_size)) + step_strategy = self.RecordingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + trainer.set_client_id(2) + + for round_number in (1, 2): + trainer.current_round = round_number + trainer.train_model(config, dataset, sampler) + + round_one_samples = step_strategy.samples_by_round[1] + round_two_samples = step_strategy.samples_by_round[2] + + assert len(round_one_samples) == config["local_steps_per_round"] + assert len(round_two_samples) == config["local_steps_per_round"] + assert round_one_samples != round_two_samples + + repeat_step_strategy = self.RecordingStepStrategy() + repeat_trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=repeat_step_strategy, + ) + repeat_trainer.set_client_id(2) + + for round_number in (1, 2): + repeat_trainer.current_round = round_number + repeat_trainer.train_model(config, dataset, sampler) + + assert repeat_step_strategy.samples_by_round == step_strategy.samples_by_round + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) def test_invalid_local_steps_fail_clearly( self, simple_model, simple_dataset, simple_config, local_steps_per_round From f6e81965a1065b296a243322cf19b2773e688a33 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:11:30 -0400 Subject: [PATCH 14/21] Handled non-materializable local-step samplers. DT-429 review found that the round-aware local-step sampler wrapper only treated TypeError as an unsupported materialization path. Samplers that raise NotImplementedError during iteration should also warn and fall back unchanged instead of failing while setting up the data loader. This patch catches NotImplementedError in the same warning/fallback path and adds regression coverage with a non-materializable sampler to verify the warning and unchanged sampler handoff. Validation: uv run pytest tests/trainers/test_composable_trainer.py -k "non_materializable or local_steps" -q; uv run pytest tests/trainers -k "local_steps or data_loader or sampler" -q; uv run pytest tests/samplers -q; uv run ruff check plato/trainers/strategies/data_loader.py tests/trainers/test_composable_trainer.py --select I; git diff --check. --- plato/trainers/strategies/data_loader.py | 2 +- tests/trainers/test_composable_trainer.py | 30 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index f65bbf3f3..9d9c5dc0e 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -80,7 +80,7 @@ def _apply_local_step_sampling_stream( try: indices = list(iter(sampler_obj)) - except TypeError: + except (TypeError, NotImplementedError): logging.warning( "Sampler %s cannot be materialized for round-aware local-step " "sampling; using it unchanged. Consecutive short local rounds may " diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index 441ab7895..a2063c933 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -5,6 +5,8 @@ it works correctly in end-to-end training scenarios. """ +import logging + import pytest import torch import torch.nn as nn @@ -197,6 +199,13 @@ def get(self): def num_samples(self): return len(self.indices) + class NonMaterializableSampler(torch.utils.data.Sampler): + def __iter__(self): + raise NotImplementedError("This sampler cannot be materialized.") + + def __len__(self): + return 10 + class CountingCallback(TrainerCallback): def __init__(self): self.train_run_end_called = False @@ -506,6 +515,27 @@ def test_local_steps_do_not_replay_same_deterministic_sampler_prefix( assert repeat_step_strategy.samples_by_round == step_strategy.samples_by_round + def test_local_step_sampling_warns_for_non_materializable_sampler( + self, simple_dataset, caplog + ): + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + sampler = self.NonMaterializableSampler() + + with caplog.at_level(logging.WARNING): + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + sampler, + batch_size=1, + context=context, + ) + + assert loader.sampler is sampler + assert ( + "cannot be materialized for round-aware local-step sampling" + in caplog.text + ) + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) def test_invalid_local_steps_fail_clearly( self, simple_model, simple_dataset, simple_config, local_steps_per_round From 711fdb11691780d1a5e65394c99c4bd7457fca10 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:18:33 -0400 Subject: [PATCH 15/21] Added exact DiLoCo smoke configuration. DT-424 adds a small MNIST/LeNet DiLoCo smoke config that uses the faithful configuration contract: server.type=diloco, algorithm.type=fedavg, local_steps_per_round=2, preserve_optimizer_state=true, AdamW inner optimizer, Nesterov outer optimizer, uniform weighting, and parameter-only outer updates. The docs now explain how to run the smoke config, distinguish algorithm mechanics from reproducing the paper C4/model/pretraining setup, and document H semantics, mid-epoch stopping, round-aware small-H sampling, local-only optimizer and scheduler state, FedAvg equivalence conditions, and the parameter/buffer policy. The integration smoke test loads the real config, verifies the contract values, and checks that the server registry selects the DiLoCo server and DiLoCo aggregation strategy. Validation: uv run pytest tests/integration/test_smoke_configs.py -k diloco -q; uv run ruff check . --select I; git diff --check. --- configs/MNIST/diloco_lenet5_smoke.toml | 68 +++++++++++++++++++++++++ docs/docs/configurations/server.md | 31 +++++++++++ docs/docs/configurations/trainer.md | 14 +++++ docs/docs/development/diloco.md | 25 +++++++-- tests/integration/test_smoke_configs.py | 36 ++++++++++++- tests/integration/utils.py | 30 +++++++++++ 6 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 configs/MNIST/diloco_lenet5_smoke.toml diff --git a/configs/MNIST/diloco_lenet5_smoke.toml b/configs/MNIST/diloco_lenet5_smoke.toml new file mode 100644 index 000000000..0c17c2ebc --- /dev/null +++ b/configs/MNIST/diloco_lenet5_smoke.toml @@ -0,0 +1,68 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 2 + +# The number of clients selected in each round +per_round = 2 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8000 +random_seed = 1 +simulate_wall_time = true +do_test = false + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] +include = "mnist_iid.toml" +partition_size = 16 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 2 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_name = "lenet5" + +# DiLoCo local work H, counted in optimizer steps. +local_steps_per_round = 2 +preserve_optimizer_state = true + +epochs = 1 +batch_size = 4 +optimizer = "AdamW" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.001 +weight_decay = 0.0 diff --git a/docs/docs/configurations/server.md b/docs/docs/configurations/server.md index 2bb800237..42d0bad99 100644 --- a/docs/docs/configurations/server.md +++ b/docs/docs/configurations/server.md @@ -8,6 +8,7 @@ - `fedavg_personalized` a Federated Averaging server that supports all-purpose personalized federated learning by controlling when and which group of clients are to perform local personalization. - `fedavg_mpc_additive` a Federated Averaging server that reconstructs additive MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_additive` processor. - `fedavg_mpc_shamir` a Federated Averaging server that reconstructs Shamir MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_shamir` processor. + - `diloco` a FedAvg-compatible server that applies DiLoCo outer aggregation. Use it with `algorithm.type = "fedavg"` and configure the outer optimizer under `[server.diloco]`. - `split_learning` a Split Learning server that supports training different kinds of models in split learning framework. When this server is used, the `clients.per_round` in the configuration should be set to 1. Users should define the rules for updating models weights before cut from the clients to the server in the callback function `on_update_weights_before_cut`, depending on the specific model they use. - `fedavg_personalized` a personalized federated learning server that starts from a number of regular rounds of federated learning. In these regular rounds, only a subset of the total clients can be selected to perform the local update (the ratio of which is a configuration setting). After all regular rounds are completed, it starts a final round of personalization, where a selected subset of clients perform local training using their local dataset. - `pfedgraph` a personalized federated learning server that aggregates models using an inferred collaboration graph and sends per-client aggregated weights. @@ -124,6 +125,36 @@ Default value: `100` +!!! example "diloco" + Settings for `server.type = "diloco"`. DiLoCo reuses `algorithm.type = "fedavg"` for client weight extraction and global model loading, while the DiLoCo server turns client deltas into an outer-gradient update. + + ```toml + [server] + type = "diloco" + + [algorithm] + type = "fedavg" + + [server.diloco] + outer_optimizer = "nesterov" + outer_learning_rate = 0.7 + outer_momentum = 0.9 + aggregation_weighting = "uniform" + apply_outer_optimizer_to = "parameters" + ``` + + `aggregation_weighting = "uniform"` matches balanced IID worker smoke runs. `aggregation_weighting = "num_samples"` matches Plato's traditional sample-weighted FedAvg behavior. With outer SGD and `outer_learning_rate = 1.0`, uniform weighting is equivalent to uniform model averaging; with `num_samples`, it is equivalent to Plato-style sample-weighted FedAvg. + + `apply_outer_optimizer_to = "parameters"` applies the outer optimizer only to trainable floating parameters. Floating buffers are synchronized with the selected averaging rule but do not receive outer momentum. `apply_outer_optimizer_to = "all_floating"` is available for experiments that also apply the outer optimizer to floating buffers. + + A runnable smoke configuration is available at `configs/MNIST/diloco_lenet5_smoke.toml`: + + ```bash + uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml + ``` + + The smoke configuration validates DiLoCo mechanics in Plato; it is not a C4/model/pretraining reproduction of the DiLoCo paper. + !!! example "edge_downlink_bandwidth" The edge server's estimated downlink capacity (an edge server to its clients) in Mbps, used for computing the transmission time (see `compute_comm_time` in the `clients` section). diff --git a/docs/docs/configurations/trainer.md b/docs/docs/configurations/trainer.md index de9eb715e..05fef2f2d 100644 --- a/docs/docs/configurations/trainer.md +++ b/docs/docs/configurations/trainer.md @@ -56,6 +56,20 @@ !!! example "epochs" The total number of epochs in local training in each communication round. +!!! example "local_steps_per_round" + The DiLoCo local work value `H`, counted as completed client-local optimizer steps between synchronizations. + + `H` is not an epoch count, raw dataloader batch count, or gradient-accumulation micro-batch count. When gradient accumulation is enabled, only batches that trigger `optimizer.step()` increment `H`. + + `H` may be smaller than one epoch. In that case, local training stops mid-epoch after exactly `H` optimizer steps while still running normal trainer cleanup, callback completion, state persistence, and reporting. + + Small-`H` DiLoCo runs use round-aware sampling where supported so a logical client does not replay the same first `H` batches every round. Trainers or samplers that cannot count optimizer steps or advance the local stream faithfully must fail or warn clearly instead of silently approximating DiLoCo. + +!!! example "preserve_optimizer_state" + Whether client-local optimizer and scheduler state should persist across a logical client's local train runs. + + DiLoCo should set this to `true` with a stateful inner optimizer such as `AdamW`. Optimizer and scheduler state remains client-local and is not transmitted in client-server payloads. + !!! example "batch_size" The size of the mini-batch of data in each step (iteration) of the training loop. diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md index 94dfd49fc..b66eae9ad 100644 --- a/docs/docs/development/diloco.md +++ b/docs/docs/development/diloco.md @@ -1,14 +1,27 @@ # DiLoCo Design Contract -This note defines what Plato will call faithful DiLoCo for the initial -implementation. It is a contract for the implementation issues that follow; it -does not describe runtime behavior that already exists in Plato. +This note defines what Plato calls faithful DiLoCo in the current +implementation. Faithful DiLoCo in Plato means algorithm-faithful execution of the DiLoCo training loop inside Plato's federated runtime. It does not mean reproducing the paper's exact C4 dataset, model scale, tokenizer, hardware topology, pretraining duration, or final benchmark numbers. +## Smoke Configuration + +Plato includes a small MNIST/LeNet smoke configuration for checking the DiLoCo +mechanics: + +```bash +uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml +``` + +This smoke run validates configuration loading, DiLoCo server selection, local +optimizer-step work, client-local optimizer-state persistence, and server-side +outer aggregation. It is intentionally tiny and does not reproduce the C4 +language-model pretraining setup or the paper's reported metrics. + ## Algorithm Contract DiLoCo has two optimizer levels: @@ -132,8 +145,10 @@ rate `1.0` is valid only when both runs use the same weighting rule. Unsupported modes must fail clearly. They must not silently fall back to an approximate DiLoCo variant. Examples include trainer backends that cannot count local optimizer steps exactly, execution paths that cannot preserve -client-local optimizer and scheduler state, or payload paths that would send -optimizer state to the server. +client-local optimizer and scheduler state, samplers that cannot advance the +small-`H` local data stream across rounds, or payload paths that would send +optimizer state to the server. Experimental combinations that are allowed but +not faithful must warn clearly. ## Implementation Sequence diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 6dbc1fa08..938ec0c66 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -4,7 +4,8 @@ from __future__ import annotations -from importlib import import_module +from importlib import import_module, reload +from pathlib import Path from types import SimpleNamespace from typing import cast @@ -17,8 +18,11 @@ async_run, build_minimal_config, configure_environment, + configure_environment_from_path, ) +REPO_ROOT = Path(__file__).resolve().parents[2] + class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" @@ -97,6 +101,36 @@ def test_fedavg_lenet5_smoke(monkeypatch): assert server.accuracy >= 0 +@pytest.mark.integration +def test_diloco_lenet5_smoke_config_contract_loads(): + """Smoke config should load the faithful DiLoCo contract.""" + config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" + + with configure_environment_from_path(config_path) as config: + assert config.server.type == "diloco" + assert config.algorithm.type == "fedavg" + assert config.trainer.local_steps_per_round == 2 + assert config.trainer.preserve_optimizer_state is True + assert config.trainer.optimizer == "AdamW" + assert config.server.diloco.outer_optimizer == "nesterov" + assert config.server.diloco.outer_learning_rate == 0.7 + assert config.server.diloco.outer_momentum == 0.9 + assert config.server.diloco.aggregation_weighting == "uniform" + assert config.server.diloco.apply_outer_optimizer_to == "parameters" + + server_registry = reload(import_module("plato.servers.registry")) + diloco_server = import_module("plato.servers.diloco") + diloco_aggregation = import_module("plato.servers.strategies.aggregation") + + server = server_registry.get() + + assert isinstance(server, diloco_server.Server) + assert isinstance( + server.aggregation_strategy, + diloco_aggregation.DiLoCoAggregationStrategy, + ) + + @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 3cb4a907b..4ff610756 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -107,6 +107,36 @@ def configure_environment(config_dict: dict): Config._instance = None +@contextlib.contextmanager +def configure_environment_from_path(config_path: Path): + """ + Context manager that initialises Config singleton from an existing config. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + Config._instance = None # reset singleton + Config.params = {} + + previous_env = os.environ.get("config_file") + previous_argv = sys.argv[:] + os.environ["config_file"] = str(config_path) + sys.argv = [ + previous_argv[0] if previous_argv else "pytest", + "--base", + tmp_dir, + ] + + try: + config = Config() + yield config + finally: + if previous_env is None: + os.environ.pop("config_file", None) + else: + os.environ["config_file"] = previous_env + sys.argv = previous_argv + Config._instance = None + + def async_run(coro): """Utility to execute the coroutine using asyncio.run (Python 3.7+).""" return asyncio.run(coro) From 082aaf1aa8980ce96ce0e5df0bf9498735a5e62b Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:28:21 -0400 Subject: [PATCH 16/21] Added end-to-end DiLoCo validation coverage. DT-426 adds final integration coverage for the faithful DiLoCo path using the exact MNIST smoke config. The test builds the configured DiLoCo server and simple client, runs local training with local_steps_per_round=2 and preserved optimizer state, verifies the outbound client payload remains model weights only, and processes deterministic server updates through the DiLoCo delta aggregation path. The validation would fail if the config selected ordinary FedAvg server aggregation, if local step control were ignored, or if the server bypassed aggregate_deltas. It directly checks the Nesterov outer update differs from ordinary FedAvg averaging, while relying on reviewed lower-level tests for small-H mid-epoch stopping, round-aware sampler non-replay, scheduler sidecar persistence, and broader payload leak coverage. Validation: uv run pytest tests/integration/test_smoke_configs.py -k diloco -q; uv run pytest tests/servers/test_diloco_strategy.py -q; uv run pytest tests/trainers -k "local_steps or optimizer_state or scheduler_state or data_loader" -q; uv run pytest tests/clients -k "payload or simple" -q; uv run ruff check . --select I; git diff --check. --- tests/integration/test_smoke_configs.py | 164 +++++++++++++++++++++++- 1 file changed, 160 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 938ec0c66..8dd1271a4 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -28,13 +28,14 @@ class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" def __init__(self, train_size: int = 4, test_size: int = 2): + generator = torch.Generator().manual_seed(13) self._train = TensorDataset( - torch.randn(train_size, 1, 28, 28), - torch.randint(0, 10, (train_size,)), + torch.randn(train_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (train_size,), generator=generator), ) self._test = TensorDataset( - torch.randn(test_size, 1, 28, 28), - torch.randint(0, 10, (test_size,)), + torch.randn(test_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (test_size,), generator=generator), ) def num_train_examples(self): @@ -47,6 +48,56 @@ def get_test_set(self): return self._test +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} + + +def _model_weight_payload(payload, model): + """Assert that a client payload contains model weights only.""" + model_state = model.state_dict() + + assert isinstance(payload, dict) + assert set(payload) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) + assert all(torch.is_tensor(value) for value in payload.values()) + + +def _shifted_payload(weights, amount): + """Build a fake client model payload shifted from the server baseline.""" + shifted = {} + for name, value in weights.items(): + shifted[name] = value.clone() + if torch.is_floating_point(shifted[name]): + shifted[name] = shifted[name] + amount + return shifted + + +def _server_update(client_id, payload): + """Build a minimal server update carrying model weights.""" + return SimpleNamespace( + client_id=client_id, + report=SimpleNamespace( + client_id=client_id, + num_samples=1, + accuracy=0.5, + processing_time=0.1, + comm_time=0.1, + training_time=0.1, + type="weights", + ), + payload=payload, + ) + + @pytest.mark.integration def test_fedavg_lenet5_smoke(monkeypatch): """End-to-end smoke test for a minimal FedAvg run.""" @@ -131,6 +182,111 @@ def test_diloco_lenet5_smoke_config_contract_loads(): ) +@pytest.mark.integration +def test_diloco_lenet5_smoke_config_runs_faithful_path(monkeypatch): + """Exact DiLoCo smoke config exercises local work and outer aggregation.""" + config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" + + with configure_environment_from_path(config_path) as config: + datasources_registry = import_module("plato.datasources.registry") + processor_registry = import_module("plato.processors.registry") + server_registry = reload(import_module("plato.servers.registry")) + client_mod = import_module("plato.clients.simple") + config_mod = import_module("plato.config") + diloco_server = import_module("plato.servers.diloco") + fedavg_algorithm = import_module("plato.algorithms.fedavg") + diloco_aggregation = import_module("plato.servers.strategies.aggregation") + + fake_datasource = MNISTSmokeDatasource(train_size=32, test_size=4) + monkeypatch.setattr( + datasources_registry, + "get", + lambda *args, **kwargs: fake_datasource, + ) + monkeypatch.setattr( + processor_registry, + "get", + lambda *args, **kwargs: (None, None), + ) + + server = server_registry.get() + server.configure() + + assert isinstance(server, diloco_server.Server) + assert isinstance(server.algorithm, fedavg_algorithm.Algorithm) + assert isinstance( + server.aggregation_strategy, + diloco_aggregation.DiLoCoAggregationStrategy, + ) + assert config.server.type == "diloco" + assert config.algorithm.type == "fedavg" + assert config.trainer.local_steps_per_round == 2 + assert config.trainer.preserve_optimizer_state is True + assert config.data.sampler == "iid" + + client = client_mod.Client() + client.client_id = 1 + client._context.client_id = 1 + client.current_round = 1 + client._context.current_round = 1 + client._load_data() + client.configure() + client._allocate_data() + client._load_payload(server.algorithm.extract_weights()) + + train_config = config.trainer._asdict() + train_config["run_id"] = config_mod.Config.params["run_id"] + client.trainer.current_round = client.current_round + client.trainer.train_model(train_config, client.trainset, client.sampler) + payload = client.algorithm.extract_weights() + + assert client.sampler.num_samples() == config.data.partition_size + assert client.trainer.context.state["local_steps_per_round"] == 2 + assert client.trainer.context.state["local_optimizer_steps"] == 2 + assert client.trainer.current_epoch == 1 + assert client.client_id in client.trainer._preserved_optimizer_states + assert client.trainer._preserved_optimizer_states[client.client_id][ + "optimizer_state" + ]["state"] + _model_weight_payload(payload, client.trainer.model) + + # Small-H mid-epoch stopping and round-aware sampler streaming are covered + # in TestComposableTrainerLocalSteps; this integration path verifies the + # exact smoke config enables those runtime flags with the supported sampler. + baseline = server.algorithm.extract_weights() + trainable_name = next(iter(dict(server.trainer.model.named_parameters()))) + server.updates = [ + _server_update(1, _shifted_payload(baseline, 1.0)), + _server_update(2, _shifted_payload(baseline, 3.0)), + ] + server.current_round = 1 + server.context.current_round = 1 + + delta_calls = [] + aggregate_deltas = server.aggregation_strategy.aggregate_deltas + + async def record_delta_aggregation(updates, deltas_received, context): + delta_calls.append((updates, deltas_received)) + return await aggregate_deltas(updates, deltas_received, context) + + monkeypatch.setattr( + server.aggregation_strategy, + "aggregate_deltas", + record_delta_aggregation, + ) + + async_run(server._process_reports()) + + updated = server.algorithm.extract_weights() + ordinary_fedavg_value = baseline[trainable_name] + 2.0 + faithful_diloco_value = baseline[trainable_name] + 2.0 * 1.9 * 0.7 + + assert len(delta_calls) == 1 + assert len(delta_calls[0][1]) == 2 + assert not torch.allclose(updated[trainable_name], ordinary_fedavg_value) + assert torch.allclose(updated[trainable_name], faithful_diloco_value) + + @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" From 1313d30897ee93f0189322367610b65590518a53 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 15:09:36 -0400 Subject: [PATCH 17/21] Restored optimizer state after moving models to device. Preserved AdamW state was loaded before ComposableTrainer moved the model to the trainer device. On GPU, PyTorch therefore mapped restored optimizer tensors to CPU and later optimizer.step() saw CUDA parameters with CPU Adam state, producing mixed-device runtime errors in later rounds. Move the model to the trainer device before optimizer construction and preserved-state restore so optimizer.load_state_dict() maps state tensors onto the same device as the optimizer parameters. Add a regression test that fails if preserved optimizer state is restored before model.to(). Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -k "restores_after_model_moves_to_device or optimizer_state or scheduler_state" -q; uv run pytest tests/trainers -k "local_steps or optimizer_state or scheduler_state or data_loader" -q; uv run pytest tests/integration/test_smoke_configs.py -k diloco -q; uv run pytest tests/clients -k "payload or simple" -q; uv run ruff check . --select I; git diff --check. --- plato/trainers/composable.py | 12 +-- .../test_composable_optimizer_state.py | 78 +++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 78ceb07bf..cbfb4dd95 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -783,8 +783,13 @@ def train_model(self, config, trainset, sampler, **kwargs): self.context.state["grad_accum_loss_total"] = 0.0 self.context.state["grad_accum_loss_count"] = 0 - # Create optimizer using strategy + # Move the model before optimizer state restore so PyTorch maps restored + # state tensors onto the same device as the optimizer parameters. model = self._require_model() + model.to(self.device) + model.train() + + # Create optimizer using strategy self.optimizer = self.optimizer_strategy.create_optimizer(model, self.context) # Create LR scheduler using strategy @@ -794,11 +799,6 @@ def train_model(self, config, trainset, sampler, **kwargs): if preserve_optimizer_state: self._restore_preserved_optimizer_state() - # Move model to device - model = self._require_model() - model.to(self.device) - model.train() - # Training epochs total_epochs = config["epochs"] tic = time.perf_counter() diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index 3620c6248..b005981f0 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -21,6 +21,7 @@ SGDOptimizerStrategy, StepLRSchedulerStrategy, ) +from plato.trainers.strategies.base import OptimizerStrategy, TrainingContext LOCAL_STATE_PAYLOAD_KEYS = { "optimizer_state", @@ -96,6 +97,51 @@ def _linear_model(): return nn.Sequential(OrderedDict([("linear", nn.Linear(2, 2))])) +class DeviceTrackingModel(nn.Module): + """Model that records whether it has been moved to a trainer device.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 2) + self.moved_to_trainer_device = False + + def forward(self, features): + return self.linear(features) + + def to(self, *args, **kwargs): + self.moved_to_trainer_device = True + return super().to(*args, **kwargs) + + +class RestoreOrderOptimizer(torch.optim.SGD): + """Optimizer that records whether state restore happens after model.to().""" + + def __init__(self, params, model: DeviceTrackingModel): + self.model = model + self.loaded_after_model_to = None + super().__init__(params, lr=0.01, momentum=0.9) + + def load_state_dict(self, state_dict): + self.loaded_after_model_to = self.model.moved_to_trainer_device + if not self.loaded_after_model_to: + raise AssertionError("optimizer state restored before model.to()") + return super().load_state_dict(state_dict) + + +class RestoreOrderOptimizerStrategy(OptimizerStrategy): + """Create restore-order-aware optimizers for regression tests.""" + + def __init__(self): + self.optimizers = [] + + def create_optimizer( + self, model: DeviceTrackingModel, context: TrainingContext + ) -> torch.optim.Optimizer: + optimizer = RestoreOrderOptimizer(model.parameters(), model) + self.optimizers.append(optimizer) + return optimizer + + def _two_layer_model(first_name="first", second_name="second"): return nn.Sequential( OrderedDict( @@ -184,6 +230,38 @@ def test_adamw_moment_buffers_persist_between_rounds_for_same_client( assert _state_step(final_param_state) == 2 +def test_preserved_optimizer_state_restores_after_model_moves_to_device( + temp_config, tiny_dataset, one_step_config +): + config = {**one_step_config, "preserve_optimizer_state": True} + source_trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=RestoreOrderOptimizerStrategy(), + ) + source_trainer.set_client_id(11) + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + restore_strategy = RestoreOrderOptimizerStrategy() + trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=restore_strategy, + ) + trainer.set_client_id(11) + trainer._preserved_optimizer_states[11] = copy.deepcopy( + source_trainer._preserved_optimizer_states[11] + ) + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert restore_strategy.optimizers[0].loaded_after_model_to is True + restored_state = _first_param_state( + trainer._preserved_optimizer_states[11]["optimizer_state"]["state"] + ) + assert "momentum_buffer" in restored_state + + def test_scheduler_state_and_lr_progress_persist_between_rounds( temp_config, tiny_dataset, one_step_config ): From f359d789633c8520e9d10c9d387e180925ce96ca Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 02:07:40 -0400 Subject: [PATCH 18/21] Logged DiLoCo outer optimizer application. Emit a server-side info log each time DiLoCo applies the configured outer optimizer to averaged client deltas. The log includes optimizer settings, aggregation weighting, apply policy, eligible update count, and optimized tensor count so runs show where the server update occurs. Validation: - uv run pytest tests/servers/test_diloco_strategy.py -q - uv run pytest tests/integration/test_smoke_configs.py -k diloco -q - uv run ruff check plato/servers/strategies/aggregation/diloco.py tests/servers/test_diloco_strategy.py --select I - git diff --check Co-authored-by: Codex --- .../servers/strategies/aggregation/diloco.py | 13 +++++++ tests/servers/test_diloco_strategy.py | 36 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index 6504990cf..ddc2428fc 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -10,6 +10,7 @@ import asyncio import copy +import logging import numbers from collections.abc import Callable, Mapping from types import SimpleNamespace @@ -86,6 +87,18 @@ async def aggregate_deltas( server_delta, active_paths = self._apply_outer_optimizer( avg_delta, optimizer_paths ) + logging.info( + "[Server] DiLoCo outer optimizer applied: optimizer=%s " + "outer_lr=%g outer_momentum=%g weighting=%s apply_to=%s " + "eligible_updates=%d optimized_tensors=%d.", + self.outer_optimizer, + self.outer_learning_rate, + self.outer_momentum, + self.aggregation_weighting, + self.apply_outer_optimizer_to, + len(eligible), + len(optimizer_paths), + ) self._remove_stale_momentum(active_paths) return self._match_reference_structure(server_delta, eligible[0][1]) diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index f35467344..4c739051d 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -1,6 +1,7 @@ """Tests for DiLoCo server-side outer aggregation.""" import asyncio +import logging from types import SimpleNamespace import pytest @@ -347,6 +348,41 @@ def test_sgd_uses_diloco_outer_gradient_sign(temp_config): assert torch.allclose(server_delta["w"], torch.tensor([1.0])) +def test_outer_optimizer_application_is_logged(temp_config, caplog): + """A DiLoCo aggregation should report the server-side outer optimizer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=0.7, + outer_momentum=0.9, + aggregation_weighting="uniform", + apply_outer_optimizer_to="parameters", + ) + model = torch.nn.Linear(1, 1, bias=False) + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + with caplog.at_level(logging.INFO): + _aggregate( + strategy, + [_update(1), _update(1)], + [ + {"weight": torch.tensor([[2.0]])}, + {"weight": torch.tensor([[4.0]])}, + ], + baseline, + model, + ) + + message = caplog.text + assert "DiLoCo outer optimizer applied" in message + assert "optimizer=nesterov" in message + assert "outer_lr=0.7" in message + assert "outer_momentum=0.9" in message + assert "weighting=uniform" in message + assert "apply_to=parameters" in message + assert "eligible_updates=2" in message + assert "optimized_tensors=1" in message + + def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): """Uniform mode should weight eligible clients equally.""" strategy = DiLoCoAggregationStrategy( From c365751e1bd160e0815f4b8bf5abd4e6ce6de28a Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 12:11:20 -0400 Subject: [PATCH 19/21] Added DiLoCo comparison configs and step-based scheduling. --- configs/CIFAR10/diloco_resnet18.toml | 79 ++++++++++++++++++ .../fedavg_resnet18_diloco_comparison.toml | 67 ++++++++++++++++ plato/trainers/composable.py | 25 +++++- plato/trainers/strategies/data_loader.py | 26 +++++- tests/trainers/test_composable_trainer.py | 80 +++++++++++++++++++ 5 files changed, 275 insertions(+), 2 deletions(-) create mode 100644 configs/CIFAR10/diloco_resnet18.toml create mode 100644 configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml diff --git a/configs/CIFAR10/diloco_resnet18.toml b/configs/CIFAR10/diloco_resnet18.toml new file mode 100644 index 000000000..70193dee5 --- /dev/null +++ b/configs/CIFAR10/diloco_resnet18.toml @@ -0,0 +1,79 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8021 + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.8 + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 +preserve_optimizer_state = true + +# DiLoCo paper inner-optimizer settings. +epochs = 250 +batch_size = 512 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml new file mode 100644 index 000000000..329572e8d --- /dev/null +++ b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml @@ -0,0 +1,67 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8022 + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.8 + +# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +epochs = 5 +batch_size = 512 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index cbfb4dd95..fc6feb886 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -206,6 +206,21 @@ def _preserve_optimizer_state(config: dict[str, Any]) -> bool: """Return whether optimizer state should survive local train runs.""" return bool(config.get("preserve_optimizer_state", False)) + @staticmethod + def _step_lr_scheduler_per_optimizer_step(config: dict[str, Any]) -> bool: + """Return whether LR scheduling should follow optimizer steps.""" + if config.get("local_steps_per_round") is None: + return False + + return getattr(Config().server, "type", None) == "diloco" + + def _step_lr_scheduler_after_optimizer_step( + self, step_lr_per_optimizer_step: bool + ) -> None: + """Advance step-based LR schedules after one completed optimizer step.""" + if step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + @staticmethod def _parameter_signature(name: str | None, parameter: torch.Tensor): """Build a compatibility signature for one model parameter.""" @@ -801,6 +816,7 @@ def train_model(self, config, trainset, sampler, **kwargs): # Training epochs total_epochs = config["epochs"] + step_lr_per_optimizer_step = self._step_lr_scheduler_per_optimizer_step(config) tic = time.perf_counter() training_stop_requested = False local_step_limit_reached = False @@ -874,6 +890,9 @@ def compute_loss(outputs, labels_inner): self.optimizer_strategy.on_optimizer_step( self.optimizer, self.context ) + self._step_lr_scheduler_after_optimizer_step( + step_lr_per_optimizer_step + ) local_step_limit_reached = self._record_local_optimizer_step( local_steps_per_round ) @@ -930,6 +949,9 @@ def compute_loss(outputs, labels_inner): ) if finalize_step_done: self.optimizer_strategy.on_optimizer_step(self.optimizer, self.context) + self._step_lr_scheduler_after_optimizer_step( + step_lr_per_optimizer_step + ) local_step_limit_reached = self._record_local_optimizer_step( local_steps_per_round ) @@ -979,7 +1001,8 @@ def compute_loss(outputs, labels_inner): self.context.state.pop("hf_optimizer_step_index", None) # LR scheduler step - self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + if not step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) # Handle optimizer params state update if needed if hasattr(self.optimizer, "params_state_update"): diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index 9d9c5dc0e..c934e48b4 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -14,6 +14,7 @@ import torch import torch.utils.data +from plato.config import Config from plato.trainers.strategies.base import DataLoaderStrategy, TrainingContext CollateFn = Callable[[list[Any]], Any] @@ -66,12 +67,35 @@ def _local_step_stream_start( return offset % stream_length +def _enforce_diloco_full_participation_for_local_steps() -> None: + """Require DiLoCo workers to train once per outer synchronization.""" + server_type = getattr(Config().server, "type", None) + if server_type != "diloco": + return + + total_clients = int(Config().clients.total_clients) + clients_per_round = int(Config().clients.per_round) + if clients_per_round == total_clients: + return + + raise ValueError( + "DiLoCo local-step data loading requires clients.per_round to equal " + "clients.total_clients so every worker advances its local data stream " + "once per outer round." + ) + + def _apply_local_step_sampling_stream( sampler_obj, batch_size: int, context: TrainingContext ): """Advance deterministic samplers across short local-step rounds.""" local_steps_per_round = context.state.get("local_steps_per_round") - if local_steps_per_round is None or sampler_obj is None: + if local_steps_per_round is None: + return sampler_obj + + _enforce_diloco_full_participation_for_local_steps() + + if sampler_obj is None: return sampler_obj samples_per_round = int(local_steps_per_round) * int(batch_size) diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index a2063c933..fd6b4a668 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -24,6 +24,7 @@ GradientAccumulationStepStrategy, NoOpUpdateStrategy, NoSchedulerStrategy, + StepLRSchedulerStrategy, ) from plato.trainers.strategies.base import ( LossCriterionStrategy, @@ -536,6 +537,85 @@ def test_local_step_sampling_warns_for_non_materializable_sampler( in caplog.text ) + def test_diloco_local_steps_require_full_client_participation( + self, simple_dataset, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + with pytest.raises( + ValueError, match="clients\\.per_round.*clients\\.total_clients" + ): + DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + def test_partial_participation_still_allowed_without_diloco_local_steps( + self, simple_dataset, temp_config + ): + Config().server.type = "fedavg" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + assert len(loader.sampler) == len(simple_dataset) + + def test_diloco_local_steps_advance_lr_scheduler_per_optimizer_step( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 4 + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 3 + + def test_non_diloco_local_steps_keep_epoch_based_lr_scheduler( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "fedavg" + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 1 + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) def test_invalid_local_steps_fail_clearly( self, simple_model, simple_dataset, simple_config, local_steps_per_round From bcc8073728b5ab0e03ed4f28438c1d6616f43505 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 13:28:09 -0400 Subject: [PATCH 20/21] Added MNIST DiLoCo comparison configs. --- ...o_lenet5_smoke.toml => diloco_lenet5.toml} | 33 +-- .../fedavg_lenet5_diloco_comparison.toml | 63 ++++++ docs/docs/configurations/server.md | 7 +- docs/docs/development/diloco.md | 18 +- tests/integration/test_smoke_configs.py | 192 +----------------- 5 files changed, 98 insertions(+), 215 deletions(-) rename configs/MNIST/{diloco_lenet5_smoke.toml => diloco_lenet5.toml} (68%) create mode 100644 configs/MNIST/fedavg_lenet5_diloco_comparison.toml diff --git a/configs/MNIST/diloco_lenet5_smoke.toml b/configs/MNIST/diloco_lenet5.toml similarity index 68% rename from configs/MNIST/diloco_lenet5_smoke.toml rename to configs/MNIST/diloco_lenet5.toml index 0c17c2ebc..1a6e68ee0 100644 --- a/configs/MNIST/diloco_lenet5_smoke.toml +++ b/configs/MNIST/diloco_lenet5.toml @@ -4,10 +4,10 @@ type = "simple" # The total number of clients -total_clients = 2 +total_clients = 50 # The number of clients selected in each round -per_round = 2 +per_round = 50 # Should the clients compute test accuracy locally? do_test = false @@ -15,10 +15,9 @@ do_test = false [server] type = "diloco" address = "127.0.0.1" -port = 8000 +port = 8001 random_seed = 1 simulate_wall_time = true -do_test = false [server.diloco] outer_optimizer = "nesterov" @@ -29,7 +28,7 @@ apply_outer_optimizer_to = "parameters" [data] include = "mnist_iid.toml" -partition_size = 16 +partition_size = 1000 [trainer] @@ -37,21 +36,26 @@ partition_size = 16 type = "basic" # The maximum number of training rounds -rounds = 2 +rounds = 20 # The maximum number of clients running concurrently -max_concurrency = 1 +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.97 # The machine learning model model_name = "lenet5" -# DiLoCo local work H, counted in optimizer steps. -local_steps_per_round = 2 +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 preserve_optimizer_state = true -epochs = 1 -batch_size = 4 +# DiLoCo paper inner-optimizer settings. +epochs = 250 +batch_size = 512 optimizer = "AdamW" +lr_scheduler = "LambdaLR" [algorithm] @@ -64,5 +68,8 @@ type = "fedavg" num_classes = 10 [parameters.optimizer] -lr = 0.001 -weight_decay = 0.0 +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml new file mode 100644 index 000000000..ecbd1c544 --- /dev/null +++ b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml @@ -0,0 +1,63 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8002 +random_seed = 1 +simulate_wall_time = true + +[data] +include = "mnist_iid.toml" +partition_size = 1000 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.97 + +# The machine learning model +model_name = "lenet5" + +# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +epochs = 5 +batch_size = 512 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/docs/docs/configurations/server.md b/docs/docs/configurations/server.md index 42d0bad99..cef578eb9 100644 --- a/docs/docs/configurations/server.md +++ b/docs/docs/configurations/server.md @@ -147,13 +147,14 @@ `apply_outer_optimizer_to = "parameters"` applies the outer optimizer only to trainable floating parameters. Floating buffers are synchronized with the selected averaging rule but do not receive outer momentum. `apply_outer_optimizer_to = "all_floating"` is available for experiments that also apply the outer optimizer to floating buffers. - A runnable smoke configuration is available at `configs/MNIST/diloco_lenet5_smoke.toml`: + Runnable comparison configurations are available for MNIST/LeNet and CIFAR-10/ResNet-18: ```bash - uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml + uv run python plato.py --config configs/MNIST/diloco_lenet5.toml + uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml ``` - The smoke configuration validates DiLoCo mechanics in Plato; it is not a C4/model/pretraining reproduction of the DiLoCo paper. + These configurations validate DiLoCo mechanics in Plato; they are not C4/model/pretraining reproductions of the DiLoCo paper. !!! example "edge_downlink_bandwidth" The edge server's estimated downlink capacity (an edge server to its clients) in Mbps, used for computing the transmission time (see `compute_comm_time` in the `clients` section). diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md index b66eae9ad..f4b8ce405 100644 --- a/docs/docs/development/diloco.md +++ b/docs/docs/development/diloco.md @@ -8,19 +8,21 @@ training loop inside Plato's federated runtime. It does not mean reproducing the paper's exact C4 dataset, model scale, tokenizer, hardware topology, pretraining duration, or final benchmark numbers. -## Smoke Configuration +## Example Configurations -Plato includes a small MNIST/LeNet smoke configuration for checking the DiLoCo -mechanics: +Plato includes MNIST/LeNet and CIFAR-10/ResNet-18 comparison configurations +for checking DiLoCo against matched FedAvg runs: ```bash -uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml +uv run python plato.py --config configs/MNIST/diloco_lenet5.toml +uv run python plato.py --config configs/MNIST/fedavg_lenet5_diloco_comparison.toml +uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml +uv run python plato.py --config configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml ``` -This smoke run validates configuration loading, DiLoCo server selection, local -optimizer-step work, client-local optimizer-state persistence, and server-side -outer aggregation. It is intentionally tiny and does not reproduce the C4 -language-model pretraining setup or the paper's reported metrics. +These examples validate Plato's DiLoCo mechanics without reproducing the C4 +dataset, tokenizer, language-model scale, hardware topology, pretraining +duration, or final benchmark numbers from the paper. ## Algorithm Contract diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 8dd1271a4..e499655be 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -4,8 +4,7 @@ from __future__ import annotations -from importlib import import_module, reload -from pathlib import Path +from importlib import import_module from types import SimpleNamespace from typing import cast @@ -18,11 +17,8 @@ async_run, build_minimal_config, configure_environment, - configure_environment_from_path, ) -REPO_ROOT = Path(__file__).resolve().parents[2] - class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" @@ -48,56 +44,6 @@ def get_test_set(self): return self._test -LOCAL_STATE_PAYLOAD_KEYS = { - "optimizer_state", - "scheduler_state", - "trainer_state", - "local_metadata", - "metadata", - "global_step", - "local_optimizer_steps", - "_optimizer_state_input_filename", - "_optimizer_state_output_filename", -} - - -def _model_weight_payload(payload, model): - """Assert that a client payload contains model weights only.""" - model_state = model.state_dict() - - assert isinstance(payload, dict) - assert set(payload) == set(model_state) - assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) - assert all(torch.is_tensor(value) for value in payload.values()) - - -def _shifted_payload(weights, amount): - """Build a fake client model payload shifted from the server baseline.""" - shifted = {} - for name, value in weights.items(): - shifted[name] = value.clone() - if torch.is_floating_point(shifted[name]): - shifted[name] = shifted[name] + amount - return shifted - - -def _server_update(client_id, payload): - """Build a minimal server update carrying model weights.""" - return SimpleNamespace( - client_id=client_id, - report=SimpleNamespace( - client_id=client_id, - num_samples=1, - accuracy=0.5, - processing_time=0.1, - comm_time=0.1, - training_time=0.1, - type="weights", - ), - payload=payload, - ) - - @pytest.mark.integration def test_fedavg_lenet5_smoke(monkeypatch): """End-to-end smoke test for a minimal FedAvg run.""" @@ -151,142 +97,6 @@ def test_fedavg_lenet5_smoke(monkeypatch): async_run(server._process_reports()) assert server.accuracy >= 0 - -@pytest.mark.integration -def test_diloco_lenet5_smoke_config_contract_loads(): - """Smoke config should load the faithful DiLoCo contract.""" - config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" - - with configure_environment_from_path(config_path) as config: - assert config.server.type == "diloco" - assert config.algorithm.type == "fedavg" - assert config.trainer.local_steps_per_round == 2 - assert config.trainer.preserve_optimizer_state is True - assert config.trainer.optimizer == "AdamW" - assert config.server.diloco.outer_optimizer == "nesterov" - assert config.server.diloco.outer_learning_rate == 0.7 - assert config.server.diloco.outer_momentum == 0.9 - assert config.server.diloco.aggregation_weighting == "uniform" - assert config.server.diloco.apply_outer_optimizer_to == "parameters" - - server_registry = reload(import_module("plato.servers.registry")) - diloco_server = import_module("plato.servers.diloco") - diloco_aggregation = import_module("plato.servers.strategies.aggregation") - - server = server_registry.get() - - assert isinstance(server, diloco_server.Server) - assert isinstance( - server.aggregation_strategy, - diloco_aggregation.DiLoCoAggregationStrategy, - ) - - -@pytest.mark.integration -def test_diloco_lenet5_smoke_config_runs_faithful_path(monkeypatch): - """Exact DiLoCo smoke config exercises local work and outer aggregation.""" - config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" - - with configure_environment_from_path(config_path) as config: - datasources_registry = import_module("plato.datasources.registry") - processor_registry = import_module("plato.processors.registry") - server_registry = reload(import_module("plato.servers.registry")) - client_mod = import_module("plato.clients.simple") - config_mod = import_module("plato.config") - diloco_server = import_module("plato.servers.diloco") - fedavg_algorithm = import_module("plato.algorithms.fedavg") - diloco_aggregation = import_module("plato.servers.strategies.aggregation") - - fake_datasource = MNISTSmokeDatasource(train_size=32, test_size=4) - monkeypatch.setattr( - datasources_registry, - "get", - lambda *args, **kwargs: fake_datasource, - ) - monkeypatch.setattr( - processor_registry, - "get", - lambda *args, **kwargs: (None, None), - ) - - server = server_registry.get() - server.configure() - - assert isinstance(server, diloco_server.Server) - assert isinstance(server.algorithm, fedavg_algorithm.Algorithm) - assert isinstance( - server.aggregation_strategy, - diloco_aggregation.DiLoCoAggregationStrategy, - ) - assert config.server.type == "diloco" - assert config.algorithm.type == "fedavg" - assert config.trainer.local_steps_per_round == 2 - assert config.trainer.preserve_optimizer_state is True - assert config.data.sampler == "iid" - - client = client_mod.Client() - client.client_id = 1 - client._context.client_id = 1 - client.current_round = 1 - client._context.current_round = 1 - client._load_data() - client.configure() - client._allocate_data() - client._load_payload(server.algorithm.extract_weights()) - - train_config = config.trainer._asdict() - train_config["run_id"] = config_mod.Config.params["run_id"] - client.trainer.current_round = client.current_round - client.trainer.train_model(train_config, client.trainset, client.sampler) - payload = client.algorithm.extract_weights() - - assert client.sampler.num_samples() == config.data.partition_size - assert client.trainer.context.state["local_steps_per_round"] == 2 - assert client.trainer.context.state["local_optimizer_steps"] == 2 - assert client.trainer.current_epoch == 1 - assert client.client_id in client.trainer._preserved_optimizer_states - assert client.trainer._preserved_optimizer_states[client.client_id][ - "optimizer_state" - ]["state"] - _model_weight_payload(payload, client.trainer.model) - - # Small-H mid-epoch stopping and round-aware sampler streaming are covered - # in TestComposableTrainerLocalSteps; this integration path verifies the - # exact smoke config enables those runtime flags with the supported sampler. - baseline = server.algorithm.extract_weights() - trainable_name = next(iter(dict(server.trainer.model.named_parameters()))) - server.updates = [ - _server_update(1, _shifted_payload(baseline, 1.0)), - _server_update(2, _shifted_payload(baseline, 3.0)), - ] - server.current_round = 1 - server.context.current_round = 1 - - delta_calls = [] - aggregate_deltas = server.aggregation_strategy.aggregate_deltas - - async def record_delta_aggregation(updates, deltas_received, context): - delta_calls.append((updates, deltas_received)) - return await aggregate_deltas(updates, deltas_received, context) - - monkeypatch.setattr( - server.aggregation_strategy, - "aggregate_deltas", - record_delta_aggregation, - ) - - async_run(server._process_reports()) - - updated = server.algorithm.extract_weights() - ordinary_fedavg_value = baseline[trainable_name] + 2.0 - faithful_diloco_value = baseline[trainable_name] + 2.0 * 1.9 * 0.7 - - assert len(delta_calls) == 1 - assert len(delta_calls[0][1]) == 2 - assert not torch.allclose(updated[trainable_name], ordinary_fedavg_value) - assert torch.allclose(updated[trainable_name], faithful_diloco_value) - - @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" From 6ffb475f423891cfde2c9c252d46d0e1752960b6 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 13:53:44 -0400 Subject: [PATCH 21/21] Aligned DiLoCo comparison budgets. --- configs/CIFAR10/diloco_resnet18.toml | 6 +++--- configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml | 7 ++++--- configs/MNIST/diloco_lenet5.toml | 6 +++--- configs/MNIST/fedavg_lenet5_diloco_comparison.toml | 9 ++++++--- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/configs/CIFAR10/diloco_resnet18.toml b/configs/CIFAR10/diloco_resnet18.toml index 70193dee5..ed407000c 100644 --- a/configs/CIFAR10/diloco_resnet18.toml +++ b/configs/CIFAR10/diloco_resnet18.toml @@ -49,15 +49,15 @@ rounds = 20 max_concurrency = 7 # The target accuracy -target_accuracy = 0.8 +target_accuracy = 0.9 # Number of local optimizer steps per DiLoCo synchronization. local_steps_per_round = 500 preserve_optimizer_state = true # DiLoCo paper inner-optimizer settings. -epochs = 250 -batch_size = 512 +epochs = 5 +batch_size = 10 optimizer = "AdamW" lr_scheduler = "LambdaLR" diff --git a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml index 329572e8d..26f32d0ce 100644 --- a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml +++ b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml @@ -41,11 +41,12 @@ rounds = 20 max_concurrency = 7 # The target accuracy -target_accuracy = 0.8 +target_accuracy = 0.9 -# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +# Match the original FedAvg local training shape while keeping 500 optimizer +# steps per round, equal to DiLoCo's H. epochs = 5 -batch_size = 512 +batch_size = 10 optimizer = "AdamW" lr_scheduler = "LambdaLR" diff --git a/configs/MNIST/diloco_lenet5.toml b/configs/MNIST/diloco_lenet5.toml index 1a6e68ee0..53eff9305 100644 --- a/configs/MNIST/diloco_lenet5.toml +++ b/configs/MNIST/diloco_lenet5.toml @@ -42,7 +42,7 @@ rounds = 20 max_concurrency = 7 # The target accuracy -target_accuracy = 0.97 +target_accuracy = 0.99 # The machine learning model model_name = "lenet5" @@ -52,8 +52,8 @@ local_steps_per_round = 500 preserve_optimizer_state = true # DiLoCo paper inner-optimizer settings. -epochs = 250 -batch_size = 512 +epochs = 5 +batch_size = 32 optimizer = "AdamW" lr_scheduler = "LambdaLR" diff --git a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml index ecbd1c544..e223915bb 100644 --- a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml +++ b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml @@ -28,20 +28,23 @@ partition_size = 1000 type = "basic" # The maximum number of training rounds -rounds = 20 +rounds = 63 # The maximum number of clients running concurrently max_concurrency = 7 # The target accuracy -target_accuracy = 0.97 +target_accuracy = 0.99 # The machine learning model model_name = "lenet5" # Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +# 5 epochs over 1000 samples at batch size 32 gives 160 optimizer steps per +# round. With 63 rounds, FedAvg gets 10,080 local steps, closely matching +# DiLoCo's 20 * H=500 = 10,000-step total budget. epochs = 5 -batch_size = 512 +batch_size = 32 optimizer = "AdamW" lr_scheduler = "LambdaLR"