From b27e6dc23489eea3b5c471dbd9d51ecd546973fc Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 15 Apr 2026 01:06:33 -0400 Subject: [PATCH 1/7] Implemented the FedDF server-aggregation example. Add a new FedDF server-aggregation example that transports proxy-set logits from clients and distills the next global model on the server. - add the FedDF client, server, aggregation strategy, algorithm, and shared proxy helpers - add a runnable MNIST/LeNet-5 example config and documentation mapping the code to the paper - cover the logits transport and server-side distillation paths with targeted tests Validation: - uv run pytest -- tests/clients/test_feddf_strategy.py tests/servers/test_feddf_server_strategy.py - uv run pytest -- tests/clients/test_feddf_strategy.py tests/servers/test_feddf_server_strategy.py tests/clients/test_fednova_strategy.py tests/servers/test_fedavg_strategy.py - uv run ruff check examples/server_aggregation/feddf tests/clients/test_feddf_strategy.py tests/servers/test_feddf_server_strategy.py - git diff --check - uv run python import/config smoke check for the FedDF example --- .../1. Server Aggregation Algorithms.md | 28 +++ examples/server_aggregation/feddf/feddf.py | 19 ++ .../feddf/feddf_MNIST_lenet5.toml | 64 +++++++ .../feddf/feddf_algorithm.py | 122 +++++++++++++ .../server_aggregation/feddf/feddf_client.py | 103 +++++++++++ .../server_aggregation/feddf/feddf_server.py | 37 ++++ .../feddf/feddf_server_strategy.py | 106 ++++++++++++ .../server_aggregation/feddf/feddf_utils.py | 92 ++++++++++ tests/clients/test_feddf_strategy.py | 70 ++++++++ tests/servers/test_feddf_server_strategy.py | 163 ++++++++++++++++++ 10 files changed, 804 insertions(+) create mode 100644 examples/server_aggregation/feddf/feddf.py create mode 100644 examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml create mode 100644 examples/server_aggregation/feddf/feddf_algorithm.py create mode 100644 examples/server_aggregation/feddf/feddf_client.py create mode 100644 examples/server_aggregation/feddf/feddf_server.py create mode 100644 examples/server_aggregation/feddf/feddf_server_strategy.py create mode 100644 examples/server_aggregation/feddf/feddf_utils.py create mode 100644 tests/clients/test_feddf_strategy.py create mode 100644 tests/servers/test_feddf_server_strategy.py diff --git a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md index ebf721219..fa98b67fb 100644 --- a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md +++ b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md @@ -56,6 +56,34 @@ Key configuration parameters: --- +### FedDF + +FedDF (Federated Distillation and Fusion) replaces direct parameter averaging with server-side distillation on a proxy set. Clients still perform standard local training, but they return teacher logits on a shared unlabeled proxy subset instead of ordinary weight deltas. The server aggregates those logits and distills their ensemble into the next global model using temperature-scaled soft targets. + +```bash +cd examples/server_aggregation/feddf/ +uv run feddf.py -c feddf_MNIST_lenet5.toml +``` + +Key configuration parameters: + +- `algorithm.proxy_set_size`: Number of unlabeled proxy samples used on the server for distillation. +- `algorithm.proxy_batch_size`: Batch size for iterating through the proxy set. +- `algorithm.proxy_seed`: Seed used to select the deterministic proxy subset shared by clients and server. +- `algorithm.temperature`: Softmax temperature used to smooth teacher logits before distillation. +- `algorithm.distillation_epochs`: Number of server-side distillation passes per round. +- `algorithm.distillation_batch_size`: Batch size for the distillation optimizer. +- `algorithm.learning_rate`: Learning rate for the server-side distillation optimizer. + +**Reference:** Tao Lin, Lingjing Kong, Sebastian U. Stich, Martin Jaggi. "[Ensemble Distillation for Robust Model Fusion in Federated Learning](https://arxiv.org/abs/2006.07242)," arXiv:2006.07242, 2020. + +!!! note "Alignment with the paper" + The module split follows the FedDF workflow directly: `feddf.py` stays as a thin launcher, `feddf_client.py` performs the standard local update and then emits proxy-set logits, `feddf_server.py` wires the custom strategy and algorithm into the FedAvg server, `feddf_server_strategy.py` reconstructs the shared proxy subset and routes the logits payload through direct weight aggregation, and `feddf_algorithm.py` encapsulates the weighted-logit ensemble plus the temperature-scaled KL distillation step. + + The configuration surface above mirrors the paper’s core knobs. `proxy_set_size`, `proxy_batch_size`, and `proxy_seed` control the deterministic unlabeled proxy data used for ensemble distillation, `temperature` shapes the softened teacher distribution, and `distillation_epochs`, `distillation_batch_size`, and `learning_rate` control the server-side student optimization that replaces direct averaging. + +--- + ### MOON MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level contrastive regularizer. Each client augments the shared model with a projection head, clones the incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints as negatives. The server still performs sample-weighted averaging but records a short history of global states for downstream analysis or warm restarts. diff --git a/examples/server_aggregation/feddf/feddf.py b/examples/server_aggregation/feddf/feddf.py new file mode 100644 index 000000000..ab6701601 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf.py @@ -0,0 +1,19 @@ +""" +Entry point for running the FedDF server aggregation example. +""" + +from __future__ import annotations + +import feddf_client +import feddf_server + + +def main(): + """Launch a Plato training session with the FedDF algorithm.""" + client = feddf_client.create_client() + server = feddf_server.Server() + server.run(client) + + +if __name__ == "__main__": + main() diff --git a/examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml b/examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml new file mode 100644 index 000000000..11b2b6c68 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml @@ -0,0 +1,64 @@ +[clients] +type = "simple" +total_clients = 100 +per_round = 10 +random_seed = 1 +do_test = false +speed_simulation = true +sleep_simulation = true +avg_training_time = 20 + +[clients.simulation_distribution] +distribution = "normal" +mean = 10 +sd = 3 + +[server] +address = "127.0.0.1" +port = 8000 +synchronous = false +simulate_wall_time = true +minimum_clients_aggregated = 6 +staleness_bound = 10 +random_seed = 1 + +[data] +datasource = "Torchvision" +dataset_name = "MNIST" +download = true +data_path = "data" +partition_size = 200 +sampler = "noniid" +concentration = 0.5 +random_seed = 1 + +[trainer] +rounds = 20 +type = "basic" +max_concurrency = 4 +target_accuracy = 1.0 +epochs = 5 +batch_size = 64 +optimizer = "SGD" +model_name = "lenet5" + +[algorithm] +type = "fedavg" +proxy_set_size = 2048 +proxy_batch_size = 128 +proxy_seed = 1 +temperature = 2.0 +distillation_epochs = 5 +distillation_batch_size = 64 +learning_rate = 0.001 + +[parameters] + +[parameters.optimizer] +lr = 0.01 +momentum = 0.9 +weight_decay = 0.0 + +[results] +result_path = "results/MNIST_lenet5/FedDF" +types = "round, accuracy, elapsed_time, comm_time, round_time" diff --git a/examples/server_aggregation/feddf/feddf_algorithm.py b/examples/server_aggregation/feddf/feddf_algorithm.py new file mode 100644 index 000000000..96db9cd42 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_algorithm.py @@ -0,0 +1,122 @@ +"""FedDF-specific helpers for ensemble distillation on the server.""" + +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Mapping, Sequence + +import torch +import torch.nn.functional as F +from feddf_utils import extract_batch_inputs, unwrap_model_outputs +from torch.utils.data import DataLoader, Dataset, TensorDataset + +from plato.algorithms import fedavg + + +class Algorithm(fedavg.Algorithm): + """Algorithm helpers for aggregating logits and distilling the student.""" + + @staticmethod + def aggregate_teacher_logits( + updates, + payloads: Sequence[Mapping[str, torch.Tensor]], + ) -> torch.Tensor: + """Compute a sample-weighted ensemble of client logits.""" + if not payloads: + raise ValueError("FedDF requires at least one logits payload.") + + first_logits = payloads[0].get("logits") + if not isinstance(first_logits, torch.Tensor): + raise TypeError("FedDF payloads must include a 'logits' tensor.") + + total_samples = sum(getattr(update.report, "num_samples", 0) for update in updates) + if total_samples <= 0: + total_samples = len(payloads) + + aggregated = torch.zeros_like(first_logits, dtype=torch.float32) + + for update, payload in zip(updates, payloads): + logits = payload.get("logits") + if not isinstance(logits, torch.Tensor): + raise TypeError("FedDF payloads must include a 'logits' tensor.") + if logits.shape != first_logits.shape: + raise ValueError( + "FedDF client logits must share the same proxy-set shape." + ) + + weight = getattr(update.report, "num_samples", 0) / total_samples + if total_samples == len(payloads): + weight = 1 / len(payloads) + + aggregated += logits.detach().float() * weight + + return aggregated + + def distill_weights( + self, + baseline_weights: Mapping[str, torch.Tensor], + teacher_logits: torch.Tensor, + proxy_dataset: Dataset, + *, + temperature: float, + distillation_epochs: int, + distillation_batch_size: int, + distillation_learning_rate: float, + ) -> OrderedDict[str, torch.Tensor]: + """Distill the server model on proxy inputs using ensemble logits.""" + if len(proxy_dataset) != len(teacher_logits): + raise ValueError( + "FedDF proxy samples and teacher logits must have matching lengths." + ) + + trainer = self.require_trainer() + model = self.require_model() + device = torch.device(getattr(trainer, "device", "cpu")) + + self.load_weights(baseline_weights) + + inputs = [] + for example in proxy_dataset: + inputs.append(extract_batch_inputs(example)) + + proxy_inputs = torch.stack(inputs) + distillation_dataset = TensorDataset(proxy_inputs, teacher_logits.detach().cpu()) + dataloader = DataLoader( + distillation_dataset, + batch_size=distillation_batch_size, + shuffle=False, + ) + + was_training = model.training + model.to(device) + model.train() + + optimizer = torch.optim.SGD( + model.parameters(), + lr=distillation_learning_rate, + ) + + for _ in range(distillation_epochs): + for batch_inputs, batch_logits in dataloader: + batch_inputs = batch_inputs.to(device) + batch_logits = batch_logits.to(device) + teacher_probs = torch.softmax(batch_logits / temperature, dim=1) + + optimizer.zero_grad() + student_logits = unwrap_model_outputs(model(batch_inputs)) + student_log_probs = F.log_softmax(student_logits / temperature, dim=1) + loss = ( + F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") + * temperature + * temperature + ) + loss.backward() + optimizer.step() + + if not was_training: + model.eval() + + return OrderedDict( + (name, tensor.detach().cpu().clone()) + for name, tensor in model.state_dict().items() + ) diff --git a/examples/server_aggregation/feddf/feddf_client.py b/examples/server_aggregation/feddf/feddf_client.py new file mode 100644 index 000000000..9193ee1ed --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_client.py @@ -0,0 +1,103 @@ +"""Client implementation for the FedDF server aggregation example.""" + +from __future__ import annotations + +from feddf_utils import ( + collect_proxy_logits, + resolve_algorithm_value, + select_proxy_subset, +) + +from plato.clients import simple +from plato.clients.strategies.defaults import DefaultTrainingStrategy + + +class FedDFTrainingStrategy(DefaultTrainingStrategy): + """Train locally, then emit teacher logits on a shared proxy set.""" + + def __init__( + self, + *, + proxy_size: int | None = None, + proxy_batch_size: int | None = None, + proxy_seed: int | None = None, + ) -> None: + super().__init__() + self.proxy_size = proxy_size + self.proxy_batch_size = proxy_batch_size + self.proxy_seed = proxy_seed + + async def train(self, context): + report, _ = await super().train(context) + + datasource = getattr(context, "datasource", None) + if datasource is None: + raise RuntimeError("FedDF requires a datasource to resolve proxy samples.") + + trainer = getattr(context, "trainer", None) + if trainer is None or getattr(trainer, "model", None) is None: + raise RuntimeError("FedDF requires a trainer with a model for logits.") + + proxy_size = resolve_algorithm_value("proxy_set_size", self.proxy_size, 512) + proxy_batch_size = resolve_algorithm_value( + "proxy_batch_size", self.proxy_batch_size, 128 + ) + proxy_seed = resolve_algorithm_value("proxy_seed", self.proxy_seed, 1) + + proxy_dataset, proxy_indices = select_proxy_subset( + datasource.get_test_set(), + size=proxy_size, + seed=proxy_seed, + ) + logits = collect_proxy_logits( + trainer.model, + proxy_dataset, + batch_size=proxy_batch_size, + device=getattr(trainer, "device", "cpu"), + ) + + context.state["feddf_proxy_indices"] = proxy_indices + report.payload_type = "feddf_logits" + report.proxy_size = len(proxy_indices) + + return report, {"logits": logits} + + +def create_client( + *, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + trainer_callbacks=None, + proxy_size: int | None = None, + proxy_batch_size: int | None = None, + proxy_seed: int | None = None, +): + """Build a client configured to emit FedDF proxy-set logits.""" + client = simple.Client( + model=model, + datasource=datasource, + algorithm=algorithm, + trainer=trainer, + callbacks=callbacks, + trainer_callbacks=trainer_callbacks, + ) + + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=FedDFTrainingStrategy( + proxy_size=proxy_size, + proxy_batch_size=proxy_batch_size, + proxy_seed=proxy_seed, + ), + reporting_strategy=client.reporting_strategy, + communication_strategy=client.communication_strategy, + ) + + return client + + +Client = create_client diff --git a/examples/server_aggregation/feddf/feddf_server.py b/examples/server_aggregation/feddf/feddf_server.py new file mode 100644 index 000000000..4caa960a1 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_server.py @@ -0,0 +1,37 @@ +"""Server wrapper for the FedDF server aggregation example.""" + +from __future__ import annotations + +from feddf_algorithm import Algorithm as FedDFAlgorithm +from feddf_server_strategy import FedDFAggregationStrategy + +from plato.servers import fedavg + + +class Server(fedavg.Server): + """A federated learning server using FedDF distillation aggregation.""" + + def __init__( + self, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + aggregation_strategy=None, + client_selection_strategy=None, + ): + if aggregation_strategy is None: + aggregation_strategy = FedDFAggregationStrategy() + + selected_algorithm = algorithm or FedDFAlgorithm + + super().__init__( + model=model, + datasource=datasource, + algorithm=selected_algorithm, + trainer=trainer, + callbacks=callbacks, + aggregation_strategy=aggregation_strategy, + client_selection_strategy=client_selection_strategy, + ) diff --git a/examples/server_aggregation/feddf/feddf_server_strategy.py b/examples/server_aggregation/feddf/feddf_server_strategy.py new file mode 100644 index 000000000..70a6f439a --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_server_strategy.py @@ -0,0 +1,106 @@ +"""FedDF aggregation strategy using server-side proxy-set distillation.""" + +from __future__ import annotations + +from typing import Mapping + +from feddf_utils import resolve_algorithm_value, select_proxy_subset +from torch.utils.data import Dataset + +from plato.datasources import registry as datasources_registry +from plato.servers.strategies.base import AggregationStrategy, ServerContext + + +class FedDFAggregationStrategy(AggregationStrategy): + """Aggregate client logits and distill them into the global student.""" + + def __init__( + self, + *, + proxy_set_size: int | None = None, + proxy_seed: int | None = None, + temperature: float | None = None, + distillation_epochs: int | None = None, + distillation_batch_size: int | None = None, + distillation_learning_rate: float | None = None, + ) -> None: + super().__init__() + self.proxy_set_size = proxy_set_size + self.proxy_seed = proxy_seed + self.temperature = temperature + self.distillation_epochs = distillation_epochs + self.distillation_batch_size = distillation_batch_size + self.distillation_learning_rate = distillation_learning_rate + + async def aggregate_deltas(self, updates, deltas_received, context: ServerContext): + """FedDF does not aggregate parameter deltas.""" + raise NotImplementedError("FedDF uses aggregate_weights with logits payloads.") + + def _resolve_proxy_dataset(self, context: ServerContext) -> Dataset: + """Construct or reuse the deterministic proxy subset.""" + cached = context.state.get("feddf_proxy_dataset") + if cached is not None: + return cached + + server = getattr(context, "server", None) + if server is None: + raise RuntimeError("FedDF requires the server in strategy context.") + + datasource = getattr(server, "datasource", None) + if datasource is None: + custom_datasource = getattr(server, "custom_datasource", None) + if custom_datasource is not None: + datasource = custom_datasource() + else: + datasource = datasources_registry.get(client_id=0) + server.datasource = datasource + + proxy_set_size = resolve_algorithm_value( + "proxy_set_size", self.proxy_set_size, 512 + ) + proxy_seed = resolve_algorithm_value("proxy_seed", self.proxy_seed, 1) + proxy_dataset, proxy_indices = select_proxy_subset( + datasource.get_test_set(), + size=proxy_set_size, + seed=proxy_seed, + ) + + context.state["feddf_proxy_dataset"] = proxy_dataset + context.state["feddf_proxy_indices"] = proxy_indices + return proxy_dataset + + async def aggregate_weights( + self, + updates, + baseline_weights: Mapping, + weights_received, + context: ServerContext, + ): + """Distill the global student from client logits on the proxy set.""" + algorithm = getattr(context, "algorithm", None) + if algorithm is None: + raise RuntimeError("FedDF requires an algorithm instance in context.") + + proxy_dataset = self._resolve_proxy_dataset(context) + teacher_logits = algorithm.aggregate_teacher_logits(updates, weights_received) + + temperature = resolve_algorithm_value("temperature", self.temperature, 2.0) + distillation_epochs = resolve_algorithm_value( + "distillation_epochs", self.distillation_epochs, 5 + ) + distillation_batch_size = resolve_algorithm_value( + "distillation_batch_size", self.distillation_batch_size, 64 + ) + distillation_learning_rate = resolve_algorithm_value( + "learning_rate", self.distillation_learning_rate, 0.001 + ) + + return algorithm.distill_weights( + baseline_weights, + teacher_logits, + proxy_dataset, + temperature=temperature, + distillation_epochs=distillation_epochs, + distillation_batch_size=distillation_batch_size, + distillation_learning_rate=distillation_learning_rate, + ) diff --git a/examples/server_aggregation/feddf/feddf_utils.py b/examples/server_aggregation/feddf/feddf_utils.py new file mode 100644 index 000000000..5a69d1062 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_utils.py @@ -0,0 +1,92 @@ +"""Shared proxy-set helpers for the FedDF server aggregation example.""" + +from __future__ import annotations + +from typing import Any + +import torch +from torch.utils.data import DataLoader, Dataset, Subset + +from plato.config import Config + + +def resolve_algorithm_value(name: str, explicit_value: Any, default: Any) -> Any: + """Resolve an example parameter from the constructor or config file.""" + if explicit_value is not None: + return explicit_value + + algorithm_cfg = getattr(Config(), "algorithm", None) + if algorithm_cfg is None: + return default + + return getattr(algorithm_cfg, name, default) + + +def select_proxy_subset( + dataset: Dataset, + *, + size: int, + seed: int, +) -> tuple[Subset, list[int]]: + """Build a deterministic subset of the shared proxy dataset.""" + total_examples = len(dataset) + if total_examples == 0: + raise ValueError("FedDF proxy dataset is empty.") + + subset_size = min(size, total_examples) + generator = torch.Generator().manual_seed(seed) + indices = torch.randperm(total_examples, generator=generator)[:subset_size].tolist() + indices.sort() + + return Subset(dataset, indices), indices + + +def unwrap_model_outputs(outputs: Any) -> torch.Tensor: + """Normalise model outputs to a logits tensor.""" + if isinstance(outputs, (tuple, list)): + return outputs[0] + if not isinstance(outputs, torch.Tensor): + raise TypeError( + "FedDF expects the model forward pass to return a tensor or " + f"tensor-like tuple, received {type(outputs).__name__}." + ) + return outputs + + +def extract_batch_inputs(batch: Any) -> Any: + """Return the model input tensor from a dataset batch.""" + if isinstance(batch, (tuple, list)): + return batch[0] + if isinstance(batch, dict): + for key in ("input", "inputs", "image", "images", "x"): + if key in batch: + return batch[key] + return batch + + +def collect_proxy_logits( + model: torch.nn.Module, + proxy_dataset: Dataset, + *, + batch_size: int, + device: torch.device | str, +) -> torch.Tensor: + """Evaluate a model on the proxy set and return detached logits.""" + was_training = model.training + device = torch.device(device) + dataloader = DataLoader(proxy_dataset, batch_size=batch_size, shuffle=False) + logits: list[torch.Tensor] = [] + + model.to(device) + model.eval() + + with torch.no_grad(): + for batch in dataloader: + inputs = extract_batch_inputs(batch).to(device) + batch_logits = unwrap_model_outputs(model(inputs)) + logits.append(batch_logits.detach().cpu()) + + if was_training: + model.train() + + return torch.cat(logits, dim=0) diff --git a/tests/clients/test_feddf_strategy.py b/tests/clients/test_feddf_strategy.py new file mode 100644 index 000000000..12ff5218a --- /dev/null +++ b/tests/clients/test_feddf_strategy.py @@ -0,0 +1,70 @@ +"""Tests for the FedDF example training strategy.""" + +from __future__ import annotations + +import asyncio +import importlib.util +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock, patch + +import torch + +from tests.test_utils.fakes import FakeDatasource, FakeModel + +_TESTS_ROOT = Path(__file__).resolve().parent +_FEDDF_DIR = ( + _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" +) +if str(_FEDDF_DIR) not in sys.path: + sys.path.insert(0, str(_FEDDF_DIR)) + +_FEDDF_CLIENT_PATH = _FEDDF_DIR / "feddf_client.py" +_FEDDF_SPEC = importlib.util.spec_from_file_location( + "feddf_client_module", _FEDDF_CLIENT_PATH +) +if _FEDDF_SPEC is None: + raise RuntimeError(f"Unable to load spec for {_FEDDF_CLIENT_PATH}") + +feddf_client = cast(Any, importlib.util.module_from_spec(_FEDDF_SPEC)) +loader = _FEDDF_SPEC.loader +if loader is None: + raise RuntimeError(f"Loader missing for {_FEDDF_CLIENT_PATH}") +loader.exec_module(feddf_client) + + +def test_feddf_training_strategy_returns_teacher_logits(temp_config): + """FedDF clients should send proxy-set logits instead of model weights.""" + strategy = feddf_client.FedDFTrainingStrategy( + proxy_size=3, + proxy_batch_size=2, + proxy_seed=11, + ) + context = SimpleNamespace( + client_id=1, + current_round=1, + datasource=FakeDatasource(test_length=5), + trainer=SimpleNamespace(model=FakeModel(), device="cpu"), + state={}, + ) + + mock_report = SimpleNamespace(num_samples=8) + async_mock = AsyncMock(return_value=(mock_report, {"weights": torch.ones(1)})) + + with patch.object( + feddf_client.DefaultTrainingStrategy, + "train", + new=async_mock, + ) as mock_train: + report, payload = asyncio.run(strategy.train(context)) + + mock_train.assert_awaited_once() + assert report is mock_report + assert getattr(report, "payload_type") == "feddf_logits" + assert getattr(report, "proxy_size") == 3 + assert "logits" in payload + assert "weights" not in payload + assert tuple(payload["logits"].shape) == (3, 2) + assert context.state["feddf_proxy_indices"] == [1, 2, 4] diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py new file mode 100644 index 000000000..f0dd9acee --- /dev/null +++ b/tests/servers/test_feddf_server_strategy.py @@ -0,0 +1,163 @@ +"""Tests for the FedDF example aggregation strategy.""" + +from __future__ import annotations + +import asyncio +import importlib.util +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast + +import torch +import torch.nn.functional as F +from torch.utils.data import TensorDataset + +from plato.config import Config + +_TESTS_ROOT = Path(__file__).resolve().parent +_FEDDF_DIR = ( + _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" +) +if str(_FEDDF_DIR) not in sys.path: + sys.path.insert(0, str(_FEDDF_DIR)) + + +def _load_module(module_name: str, path: Path) -> Any: + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None: + raise RuntimeError(f"Unable to load spec for {path}") + + module = cast(Any, importlib.util.module_from_spec(spec)) + loader = spec.loader + if loader is None: + raise RuntimeError(f"Loader missing for {path}") + + loader.exec_module(module) + return module + + +feddf_algorithm = _load_module( + "feddf_algorithm_module", + _FEDDF_DIR / "feddf_algorithm.py", +) +feddf_server_strategy = _load_module( + "feddf_server_strategy_module", + _FEDDF_DIR / "feddf_server_strategy.py", +) + + +class TinyStudent(torch.nn.Module): + """Small student model used to verify server-side distillation.""" + + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2, bias=False) + with torch.no_grad(): + self.linear.weight.zero_() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.linear(inputs) + + +def test_feddf_server_process_reports_distills_global_model(temp_config): + """FedDF should consume logits payloads and update the global model.""" + from plato.servers import fedavg + + Config().server.do_test = False + + trainer = SimpleNamespace(model=TinyStudent(), device="cpu") + algorithm = feddf_algorithm.Algorithm(trainer=trainer) + strategy = feddf_server_strategy.FedDFAggregationStrategy( + temperature=1.0, + distillation_epochs=80, + distillation_batch_size=2, + distillation_learning_rate=0.4, + ) + server = fedavg.Server(aggregation_strategy=strategy) + server.algorithm = algorithm + server.trainer = trainer + server.context.server = server + server.context.algorithm = algorithm + server.context.trainer = trainer + server.context.state["prng_state"] = None + + proxy_inputs = torch.tensor( + [ + [2.0, 0.0], + [0.0, 2.0], + [1.5, 0.2], + [0.2, 1.5], + ] + ) + proxy_dataset = TensorDataset(proxy_inputs, torch.zeros(len(proxy_inputs))) + server.context.state["feddf_proxy_dataset"] = proxy_dataset + + teacher_logits_a = torch.tensor( + [ + [7.0, -7.0], + [-7.0, 7.0], + [5.5, -5.5], + [-5.5, 5.5], + ] + ) + teacher_logits_b = torch.tensor( + [ + [6.0, -6.0], + [-6.0, 6.0], + [4.5, -4.5], + [-4.5, 4.5], + ] + ) + + server.updates = [ + SimpleNamespace( + client_id=1, + report=SimpleNamespace( + num_samples=1, + accuracy=0.0, + processing_time=0.0, + comm_time=0.0, + training_time=0.0, + ), + payload={"logits": teacher_logits_a}, + ), + SimpleNamespace( + client_id=2, + report=SimpleNamespace( + num_samples=1, + accuracy=0.0, + processing_time=0.0, + comm_time=0.0, + training_time=0.0, + ), + payload={"logits": teacher_logits_b}, + ), + ] + + baseline_log_probs = torch.log_softmax(torch.zeros_like(teacher_logits_a), dim=1) + teacher_targets = torch.softmax((teacher_logits_a + teacher_logits_b) / 2, dim=1) + baseline_loss = F.kl_div( + baseline_log_probs, + teacher_targets, + reduction="batchmean", + ) + + asyncio.run(server._process_reports()) + + updated_weights = algorithm.extract_weights() + assert not torch.allclose( + updated_weights["linear.weight"], + torch.zeros_like(updated_weights["linear.weight"]), + ) + + with torch.no_grad(): + updated_logits = trainer.model(proxy_inputs) + updated_log_probs = torch.log_softmax(updated_logits, dim=1) + + distilled_loss = F.kl_div( + updated_log_probs, + teacher_targets, + reduction="batchmean", + ) + assert distilled_loss < baseline_loss From 754d8381c08eb2348b42fd5dd90068653fe43f05 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 15 Apr 2026 01:15:27 -0400 Subject: [PATCH 2/7] Shared the FedDF proxy payload between server and clients. Address the DT-335 review findings by making the proxy protocol explicit instead of assuming that every client can reconstruct the same proxy subset locally. - send the server-selected proxy inputs together with the global weights in the FedDF server payload - load and reuse the shared proxy payload on the client before emitting logits - update the tests to exercise the real proxy-resolution path and document the repaired protocol Validation: - uv run pytest -- tests/clients/test_feddf_strategy.py tests/servers/test_feddf_server_strategy.py - uv run pytest -- tests/clients/test_feddf_strategy.py tests/servers/test_feddf_server_strategy.py tests/clients/test_fednova_strategy.py tests/servers/test_fedavg_strategy.py - uv run ruff check examples/server_aggregation/feddf tests/clients/test_feddf_strategy.py tests/servers/test_feddf_server_strategy.py - git diff --check - uv run python import/config smoke check for the FedDF example --- .../1. Server Aggregation Algorithms.md | 4 +- .../server_aggregation/feddf/feddf_client.py | 44 +++++++++---------- .../server_aggregation/feddf/feddf_server.py | 12 +++++ .../server_aggregation/feddf/feddf_utils.py | 9 ++++ tests/clients/test_feddf_strategy.py | 16 ++++--- tests/servers/test_feddf_server_strategy.py | 40 ++++++++++++----- 6 files changed, 83 insertions(+), 42 deletions(-) diff --git a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md index fa98b67fb..05167eadb 100644 --- a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md +++ b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md @@ -58,7 +58,7 @@ Key configuration parameters: ### FedDF -FedDF (Federated Distillation and Fusion) replaces direct parameter averaging with server-side distillation on a proxy set. Clients still perform standard local training, but they return teacher logits on a shared unlabeled proxy subset instead of ordinary weight deltas. The server aggregates those logits and distills their ensemble into the next global model using temperature-scaled soft targets. +FedDF (Federated Distillation and Fusion) replaces direct parameter averaging with server-side distillation on a proxy set. The server selects a deterministic unlabeled proxy subset, ships those proxy inputs alongside the current global weights, and each client returns teacher logits on that shared payload instead of ordinary weight deltas. The server then aggregates those logits and distills their ensemble into the next global model using temperature-scaled soft targets. ```bash cd examples/server_aggregation/feddf/ @@ -78,7 +78,7 @@ Key configuration parameters: **Reference:** Tao Lin, Lingjing Kong, Sebastian U. Stich, Martin Jaggi. "[Ensemble Distillation for Robust Model Fusion in Federated Learning](https://arxiv.org/abs/2006.07242)," arXiv:2006.07242, 2020. !!! note "Alignment with the paper" - The module split follows the FedDF workflow directly: `feddf.py` stays as a thin launcher, `feddf_client.py` performs the standard local update and then emits proxy-set logits, `feddf_server.py` wires the custom strategy and algorithm into the FedAvg server, `feddf_server_strategy.py` reconstructs the shared proxy subset and routes the logits payload through direct weight aggregation, and `feddf_algorithm.py` encapsulates the weighted-logit ensemble plus the temperature-scaled KL distillation step. + The module split follows the FedDF workflow directly: `feddf.py` stays as a thin launcher, `feddf_server.py` packages the current global weights together with the shared proxy inputs, `feddf_client.py` performs the standard local update and then emits logits on that server-supplied proxy payload, `feddf_server_strategy.py` resolves the deterministic proxy subset and routes the logits payload through direct weight aggregation, and `feddf_algorithm.py` encapsulates the weighted-logit ensemble plus the temperature-scaled KL distillation step. The configuration surface above mirrors the paper’s core knobs. `proxy_set_size`, `proxy_batch_size`, and `proxy_seed` control the deterministic unlabeled proxy data used for ensemble distillation, `temperature` shapes the softened teacher distribution, and `distillation_epochs`, `distillation_batch_size`, and `learning_rate` control the server-side student optimization that replaces direct averaging. diff --git a/examples/server_aggregation/feddf/feddf_client.py b/examples/server_aggregation/feddf/feddf_client.py index 9193ee1ed..b0fcf6a59 100644 --- a/examples/server_aggregation/feddf/feddf_client.py +++ b/examples/server_aggregation/feddf/feddf_client.py @@ -5,8 +5,8 @@ from feddf_utils import ( collect_proxy_logits, resolve_algorithm_value, - select_proxy_subset, ) +from torch.utils.data import TensorDataset from plato.clients import simple from plato.clients.strategies.defaults import DefaultTrainingStrategy @@ -18,47 +18,47 @@ class FedDFTrainingStrategy(DefaultTrainingStrategy): def __init__( self, *, - proxy_size: int | None = None, proxy_batch_size: int | None = None, - proxy_seed: int | None = None, ) -> None: super().__init__() - self.proxy_size = proxy_size self.proxy_batch_size = proxy_batch_size - self.proxy_seed = proxy_seed + + def load_payload(self, context, server_payload) -> None: + """Load model weights and cache the shared proxy inputs from the server.""" + if not isinstance(server_payload, dict): + raise TypeError("FedDF expects a dictionary payload from the server.") + + if "weights" not in server_payload or "proxy_inputs" not in server_payload: + raise KeyError( + "FedDF server payload must include 'weights' and 'proxy_inputs'." + ) + + context.state["feddf_proxy_inputs"] = server_payload["proxy_inputs"] + super().load_payload(context, server_payload["weights"]) async def train(self, context): report, _ = await super().train(context) - datasource = getattr(context, "datasource", None) - if datasource is None: - raise RuntimeError("FedDF requires a datasource to resolve proxy samples.") - trainer = getattr(context, "trainer", None) if trainer is None or getattr(trainer, "model", None) is None: raise RuntimeError("FedDF requires a trainer with a model for logits.") - proxy_size = resolve_algorithm_value("proxy_set_size", self.proxy_size, 512) + proxy_inputs = context.state.get("feddf_proxy_inputs") + if proxy_inputs is None: + raise RuntimeError("FedDF requires shared proxy inputs from the server.") + proxy_batch_size = resolve_algorithm_value( "proxy_batch_size", self.proxy_batch_size, 128 ) - proxy_seed = resolve_algorithm_value("proxy_seed", self.proxy_seed, 1) - - proxy_dataset, proxy_indices = select_proxy_subset( - datasource.get_test_set(), - size=proxy_size, - seed=proxy_seed, - ) logits = collect_proxy_logits( trainer.model, - proxy_dataset, + TensorDataset(proxy_inputs), batch_size=proxy_batch_size, device=getattr(trainer, "device", "cpu"), ) - context.state["feddf_proxy_indices"] = proxy_indices report.payload_type = "feddf_logits" - report.proxy_size = len(proxy_indices) + report.proxy_size = len(proxy_inputs) return report, {"logits": logits} @@ -71,9 +71,7 @@ def create_client( trainer=None, callbacks=None, trainer_callbacks=None, - proxy_size: int | None = None, proxy_batch_size: int | None = None, - proxy_seed: int | None = None, ): """Build a client configured to emit FedDF proxy-set logits.""" client = simple.Client( @@ -89,9 +87,7 @@ def create_client( lifecycle_strategy=client.lifecycle_strategy, payload_strategy=client.payload_strategy, training_strategy=FedDFTrainingStrategy( - proxy_size=proxy_size, proxy_batch_size=proxy_batch_size, - proxy_seed=proxy_seed, ), reporting_strategy=client.reporting_strategy, communication_strategy=client.communication_strategy, diff --git a/examples/server_aggregation/feddf/feddf_server.py b/examples/server_aggregation/feddf/feddf_server.py index 4caa960a1..192327b8d 100644 --- a/examples/server_aggregation/feddf/feddf_server.py +++ b/examples/server_aggregation/feddf/feddf_server.py @@ -4,6 +4,7 @@ from feddf_algorithm import Algorithm as FedDFAlgorithm from feddf_server_strategy import FedDFAggregationStrategy +from feddf_utils import stack_proxy_inputs from plato.servers import fedavg @@ -35,3 +36,14 @@ def __init__( aggregation_strategy=aggregation_strategy, client_selection_strategy=client_selection_strategy, ) + + def customize_server_payload(self, payload): + """Send weights together with the shared proxy inputs for FedDF.""" + proxy_dataset = self.aggregation_strategy._resolve_proxy_dataset(self.context) + proxy_inputs = stack_proxy_inputs(proxy_dataset) + self.context.state["feddf_proxy_inputs"] = proxy_inputs + + return { + "weights": payload, + "proxy_inputs": proxy_inputs, + } diff --git a/examples/server_aggregation/feddf/feddf_utils.py b/examples/server_aggregation/feddf/feddf_utils.py index 5a69d1062..bd3be1c4c 100644 --- a/examples/server_aggregation/feddf/feddf_utils.py +++ b/examples/server_aggregation/feddf/feddf_utils.py @@ -90,3 +90,12 @@ def collect_proxy_logits( model.train() return torch.cat(logits, dim=0) + + +def stack_proxy_inputs(proxy_dataset: Dataset) -> torch.Tensor: + """Materialise the proxy inputs in their deterministic dataset order.""" + inputs = [] + for example in proxy_dataset: + inputs.append(extract_batch_inputs(example)) + + return torch.stack(inputs) diff --git a/tests/clients/test_feddf_strategy.py b/tests/clients/test_feddf_strategy.py index 12ff5218a..66f9bd1b4 100644 --- a/tests/clients/test_feddf_strategy.py +++ b/tests/clients/test_feddf_strategy.py @@ -12,7 +12,7 @@ import torch -from tests.test_utils.fakes import FakeDatasource, FakeModel +from tests.test_utils.fakes import FakeModel _TESTS_ROOT = Path(__file__).resolve().parent _FEDDF_DIR = ( @@ -38,17 +38,22 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): """FedDF clients should send proxy-set logits instead of model weights.""" strategy = feddf_client.FedDFTrainingStrategy( - proxy_size=3, proxy_batch_size=2, - proxy_seed=11, ) + loaded_weights = [] context = SimpleNamespace( client_id=1, current_round=1, - datasource=FakeDatasource(test_length=5), + algorithm=SimpleNamespace(load_weights=lambda weights: loaded_weights.append(weights)), trainer=SimpleNamespace(model=FakeModel(), device="cpu"), state={}, ) + proxy_inputs = torch.randn(3, 4) + inbound_payload = { + "weights": {"linear.weight": torch.ones(2, 4)}, + "proxy_inputs": proxy_inputs, + } + strategy.load_payload(context, inbound_payload) mock_report = SimpleNamespace(num_samples=8) async_mock = AsyncMock(return_value=(mock_report, {"weights": torch.ones(1)})) @@ -67,4 +72,5 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): assert "logits" in payload assert "weights" not in payload assert tuple(payload["logits"].shape) == (3, 2) - assert context.state["feddf_proxy_indices"] == [1, 2, 4] + assert loaded_weights == [inbound_payload["weights"]] + assert torch.equal(context.state["feddf_proxy_inputs"], proxy_inputs) diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py index f0dd9acee..c76887dd9 100644 --- a/tests/servers/test_feddf_server_strategy.py +++ b/tests/servers/test_feddf_server_strategy.py @@ -60,28 +60,35 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return self.linear(inputs) +class SharedProxyDatasource: + """Datasource stub exposing a shared test split for FedDF.""" + + def __init__(self, proxy_inputs: torch.Tensor) -> None: + self._test = TensorDataset(proxy_inputs, torch.zeros(len(proxy_inputs))) + + def get_test_set(self): + return self._test + + def test_feddf_server_process_reports_distills_global_model(temp_config): """FedDF should consume logits payloads and update the global model.""" - from plato.servers import fedavg + feddf_server = _load_module( + "feddf_server_module", + _FEDDF_DIR / "feddf_server.py", + ) Config().server.do_test = False trainer = SimpleNamespace(model=TinyStudent(), device="cpu") algorithm = feddf_algorithm.Algorithm(trainer=trainer) strategy = feddf_server_strategy.FedDFAggregationStrategy( + proxy_set_size=4, + proxy_seed=1, temperature=1.0, distillation_epochs=80, distillation_batch_size=2, distillation_learning_rate=0.4, ) - server = fedavg.Server(aggregation_strategy=strategy) - server.algorithm = algorithm - server.trainer = trainer - server.context.server = server - server.context.algorithm = algorithm - server.context.trainer = trainer - server.context.state["prng_state"] = None - proxy_inputs = torch.tensor( [ [2.0, 0.0], @@ -90,8 +97,19 @@ def test_feddf_server_process_reports_distills_global_model(temp_config): [0.2, 1.5], ] ) - proxy_dataset = TensorDataset(proxy_inputs, torch.zeros(len(proxy_inputs))) - server.context.state["feddf_proxy_dataset"] = proxy_dataset + server = feddf_server.Server( + aggregation_strategy=strategy, + datasource=lambda: SharedProxyDatasource(proxy_inputs), + ) + server.algorithm = algorithm + server.trainer = trainer + server.context.server = server + server.context.algorithm = algorithm + server.context.trainer = trainer + server.context.state["prng_state"] = None + server_payload = server.customize_server_payload(algorithm.extract_weights()) + assert set(server_payload.keys()) == {"weights", "proxy_inputs"} + assert torch.equal(server_payload["proxy_inputs"], proxy_inputs) teacher_logits_a = torch.tensor( [ From 0dce1448b125b05341b0d3db2916a9ac1723c2c3 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 15 Apr 2026 01:29:42 -0400 Subject: [PATCH 3/7] Reserved deterministic FedDF proxy holdouts. Prefer the datasource unlabeled split for FedDF proxy selection before falling back to the test split. Added deterministic Torchvision subset slicing so CIFAR-10 can reserve a non-overlapping training holdout for proxy distillation while preserving classes and targets metadata. Expanded datasource and FedDF strategy coverage for the new subset semantics and unlabeled-first proxy path. --- .../feddf/feddf_server_strategy.py | 17 ++- plato/datasources/torchvision.py | 116 ++++++++++++++++-- .../test_torchvision_datasource.py | 59 +++++++++ tests/servers/test_feddf_server_strategy.py | 29 ++++- 4 files changed, 205 insertions(+), 16 deletions(-) diff --git a/examples/server_aggregation/feddf/feddf_server_strategy.py b/examples/server_aggregation/feddf/feddf_server_strategy.py index 70a6f439a..73b1d6f58 100644 --- a/examples/server_aggregation/feddf/feddf_server_strategy.py +++ b/examples/server_aggregation/feddf/feddf_server_strategy.py @@ -36,6 +36,21 @@ async def aggregate_deltas(self, updates, deltas_received, context: ServerContex """FedDF does not aggregate parameter deltas.""" raise NotImplementedError("FedDF uses aggregate_weights with logits payloads.") + @staticmethod + def _proxy_source_dataset(datasource): + """Prefer an unlabeled proxy split, falling back to the test split.""" + if hasattr(datasource, "get_unlabeled_set"): + unlabeled_set = datasource.get_unlabeled_set() + if unlabeled_set is not None: + return unlabeled_set + + test_set = datasource.get_test_set() + if test_set is None: + raise RuntimeError( + "FedDF requires either an unlabeled proxy split or a test split." + ) + return test_set + def _resolve_proxy_dataset(self, context: ServerContext) -> Dataset: """Construct or reuse the deterministic proxy subset.""" cached = context.state.get("feddf_proxy_dataset") @@ -60,7 +75,7 @@ def _resolve_proxy_dataset(self, context: ServerContext) -> Dataset: ) proxy_seed = resolve_algorithm_value("proxy_seed", self.proxy_seed, 1) proxy_dataset, proxy_indices = select_proxy_subset( - datasource.get_test_set(), + self._proxy_source_dataset(datasource), size=proxy_set_size, seed=proxy_seed, ) diff --git a/plato/datasources/torchvision.py b/plato/datasources/torchvision.py index 5a6eb116d..64f2c2dcc 100644 --- a/plato/datasources/torchvision.py +++ b/plato/datasources/torchvision.py @@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional import torch +from torch.utils.data import Subset from torchvision import datasets, transforms from plato.config import Config @@ -52,6 +53,30 @@ def _to_plain_list(value: Any) -> list[Any]: raise TypeError("Expected a list-like object for dataset arguments.") +def _normalize_subset_spec(value: Any) -> dict[str, int] | None: + """Normalize an optional subset-selection config into plain integers.""" + spec = _to_plain_dict(value) + if not spec: + return None + + normalized: dict[str, int] = {} + if "seed" in spec: + normalized["seed"] = int(spec["seed"]) + + start = int(spec.get("start", 0)) + if start < 0: + raise ValueError("Subset start must be non-negative.") + normalized["start"] = start + + if "size" in spec: + size = int(spec["size"]) + if size < 0: + raise ValueError("Subset size must be non-negative.") + normalized["size"] = size + + return normalized + + def _default_transform(): """Factory providing a fresh default transform for image datasets.""" return transforms.ToTensor() @@ -264,6 +289,11 @@ def __init__(self, **kwargs): user_unlabeled_kwargs = _to_plain_dict( getattr(data_cfg, "unlabeled_kwargs", None) ) + train_subset = _normalize_subset_spec(getattr(data_cfg, "train_subset", None)) + test_subset = _normalize_subset_spec(getattr(data_cfg, "test_subset", None)) + unlabeled_subset = _normalize_subset_spec( + getattr(data_cfg, "unlabeled_subset", None) + ) common_args = list(dataset_defaults.get("dataset_args", [])) + _to_plain_list( getattr(data_cfg, "dataset_args", None) @@ -355,6 +385,9 @@ def __init__(self, **kwargs): train_transform, train_target_transform, ) + self.trainset = self._subset_dataset( + self.trainset, train_subset, subset_name="train" + ) else: self.trainset = None @@ -375,6 +408,9 @@ def __init__(self, **kwargs): test_transform, test_target_transform, ) + self.testset = self._subset_dataset( + self.testset, test_subset, subset_name="test" + ) else: self.testset = None @@ -395,6 +431,9 @@ def __init__(self, **kwargs): unlabeled_transform, unlabeled_target_transform, ) + self.unlabeledset = self._subset_dataset( + self.unlabeledset, unlabeled_subset, subset_name="unlabeled" + ) else: self.unlabeledset = None @@ -525,25 +564,82 @@ def _attach_metadata(dataset): if not hasattr(dataset, "classes") and hasattr(dataset, "class_to_idx"): dataset.classes = list(dataset.class_to_idx.keys()) - def classes(self): - dataset = self.trainset or self.testset - if dataset is None: - return [] + @staticmethod + def _dataset_classes(dataset): if hasattr(dataset, "classes") and dataset.classes is not None: return list(dataset.classes) if hasattr(dataset, "class_to_idx"): return list(dataset.class_to_idx.keys()) - return [] + return None + + @staticmethod + def _dataset_targets(dataset): + targets = None + if hasattr(dataset, "targets"): + targets = dataset.targets + elif hasattr(dataset, "labels"): + targets = dataset.labels + + if targets is None: + return None + if isinstance(targets, torch.Tensor): + return targets.tolist() + if isinstance(targets, tuple): + return list(targets) + if hasattr(targets, "tolist") and not isinstance(targets, list): + return targets.tolist() + return list(targets) + + def _subset_dataset(self, dataset, subset_spec, *, subset_name: str): + """Apply a deterministic subset slice while preserving metadata.""" + if dataset is None or subset_spec is None: + return dataset + + total_examples = len(dataset) + start = subset_spec.get("start", 0) + size = subset_spec.get("size", total_examples - start) + stop = start + size + + if start > total_examples: + raise ValueError( + f"{subset_name}_subset start {start} exceeds dataset size {total_examples}." + ) + if stop > total_examples: + raise ValueError( + f"{subset_name}_subset stop {stop} exceeds dataset size {total_examples}." + ) + + indices = list(range(total_examples)) + if "seed" in subset_spec: + generator = torch.Generator().manual_seed(subset_spec["seed"]) + indices = torch.randperm(total_examples, generator=generator).tolist() + + selected_indices = indices[start:stop] + subset = Subset(dataset, selected_indices) + + classes = self._dataset_classes(dataset) + if classes is not None: + subset.classes = classes + + targets = self._dataset_targets(dataset) + if targets is not None: + subset.targets = [targets[index] for index in selected_indices] + + return subset + + def classes(self): + dataset = self.trainset or self.testset + if dataset is None: + return [] + classes = self._dataset_classes(dataset) + return [] if classes is None else classes def targets(self): dataset = self.trainset or self.testset if dataset is None: return [] - if hasattr(dataset, "targets"): - return dataset.targets - if hasattr(dataset, "labels"): - return dataset.labels - return [] + targets = self._dataset_targets(dataset) + return [] if targets is None else targets def get_unlabeled_set(self): return getattr(self, "unlabeledset", None) diff --git a/tests/datasources/test_torchvision_datasource.py b/tests/datasources/test_torchvision_datasource.py index fbd52ed23..ada748dd3 100644 --- a/tests/datasources/test_torchvision_datasource.py +++ b/tests/datasources/test_torchvision_datasource.py @@ -189,6 +189,65 @@ def __getitem__(self, index): assert datasource.classes() == ["neg", "pos"] +def test_torchvision_datasource_supports_deterministic_non_overlapping_subsets( + monkeypatch, tmp_path +): + """Subset configs should carve deterministic, disjoint slices from one split.""" + + class DummyBoolDataset: + def __init__( + self, + root, + train=True, + download=False, + transform=None, + ): + self.root = root + self.train = train + self.download = download + self.transform = transform + self.targets = list(range(10)) + self.classes = ("neg", "pos") + self.data = list(range(10)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index], self.targets[index] + + stub_datasets = types.SimpleNamespace(DummyBoolDataset=DummyBoolDataset) + dummy_config = _build_config( + tmp_path, + { + "datasource": "Torchvision", + "dataset_name": "DummyBoolDataset", + "download": False, + "unlabeled_split": "train", + "train_subset": {"seed": 7, "start": 2, "size": 4}, + "unlabeled_subset": {"seed": 7, "start": 0, "size": 2}, + }, + ) + + monkeypatch.setattr(torchvision_ds, "datasets", stub_datasets) + monkeypatch.setattr(torchvision_ds, "transforms", _StubTransforms()) + monkeypatch.setattr(torchvision_ds, "Config", lambda: dummy_config) + + datasource = torchvision_ds.DataSource() + + expected_indices = torch.randperm( + 10, generator=torch.Generator().manual_seed(7) + ).tolist() + assert datasource.trainset.indices == expected_indices[2:6] + assert datasource.get_unlabeled_set().indices == expected_indices[:2] + assert set(datasource.trainset.indices).isdisjoint( + datasource.get_unlabeled_set().indices + ) + assert datasource.targets() == [3, 4, 1, 7] + assert datasource.get_unlabeled_set().targets == [5, 0] + assert datasource.classes() == ["neg", "pos"] + + def test_torchvision_datasource_celeba_defaults(monkeypatch, tmp_path): """CelebA should inherit legacy defaults including target handling.""" diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py index c76887dd9..c763585d6 100644 --- a/tests/servers/test_feddf_server_strategy.py +++ b/tests/servers/test_feddf_server_strategy.py @@ -61,10 +61,21 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class SharedProxyDatasource: - """Datasource stub exposing a shared test split for FedDF.""" - - def __init__(self, proxy_inputs: torch.Tensor) -> None: - self._test = TensorDataset(proxy_inputs, torch.zeros(len(proxy_inputs))) + """Datasource stub exposing both unlabeled and test splits for FedDF.""" + + def __init__( + self, + proxy_inputs: torch.Tensor, + test_inputs: torch.Tensor | None = None, + ) -> None: + self._unlabeled = TensorDataset(proxy_inputs, torch.zeros(len(proxy_inputs))) + self._test = TensorDataset( + test_inputs if test_inputs is not None else proxy_inputs, + torch.zeros(len(test_inputs) if test_inputs is not None else len(proxy_inputs)), + ) + + def get_unlabeled_set(self): + return self._unlabeled def get_test_set(self): return self._test @@ -97,9 +108,17 @@ def test_feddf_server_process_reports_distills_global_model(temp_config): [0.2, 1.5], ] ) + held_out_test_inputs = torch.tensor( + [ + [9.0, 9.0], + [8.0, 8.0], + [7.0, 7.0], + [6.0, 6.0], + ] + ) server = feddf_server.Server( aggregation_strategy=strategy, - datasource=lambda: SharedProxyDatasource(proxy_inputs), + datasource=lambda: SharedProxyDatasource(proxy_inputs, held_out_test_inputs), ) server.algorithm = algorithm server.trainer = trainer From 66c4e617c15e0afbab00615d8bc496bb2c049ef3 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 15 Apr 2026 06:55:42 -0400 Subject: [PATCH 4/7] Aligned FedDF distillation with uniform AVGLOGITS and Adam. Match the FedDF server distillation path more closely to the paper/reference recipe by defaulting teacher fusion to uniform AVGLOGITS, switching the student distillation loop to Adam with optional cosine annealing, and shuffling proxy batches during server training. Keep sample-weighted fusion and SGD available as explicit opt-in settings for experiments that need them. Add a focused server-side regression test to lock the uniform teacher-logit averaging default and keep the existing FedDF distillation test green under the new optimizer/scheduler path. Validation: uv run pytest tests/servers/test_feddf_server_strategy.py tests/clients/test_feddf_strategy.py -q; uv run ruff check examples/server_aggregation/feddf tests/servers/test_feddf_server_strategy.py tests/clients/test_feddf_strategy.py --- .../feddf/feddf_algorithm.py | 52 +++++++++++++++---- .../feddf/feddf_server_strategy.py | 35 +++++++++++-- tests/servers/test_feddf_server_strategy.py | 18 +++++++ 3 files changed, 92 insertions(+), 13 deletions(-) diff --git a/examples/server_aggregation/feddf/feddf_algorithm.py b/examples/server_aggregation/feddf/feddf_algorithm.py index 96db9cd42..57164529e 100644 --- a/examples/server_aggregation/feddf/feddf_algorithm.py +++ b/examples/server_aggregation/feddf/feddf_algorithm.py @@ -20,8 +20,10 @@ class Algorithm(fedavg.Algorithm): def aggregate_teacher_logits( updates, payloads: Sequence[Mapping[str, torch.Tensor]], + *, + weighting: str = "uniform", ) -> torch.Tensor: - """Compute a sample-weighted ensemble of client logits.""" + """Compute the ensembled teacher logits for AVGLOGITS distillation.""" if not payloads: raise ValueError("FedDF requires at least one logits payload.") @@ -29,9 +31,14 @@ def aggregate_teacher_logits( if not isinstance(first_logits, torch.Tensor): raise TypeError("FedDF payloads must include a 'logits' tensor.") + weighting_name = weighting.strip().lower() + if weighting_name not in {"uniform", "samples"}: + raise ValueError( + "FedDF teacher weighting must be either 'uniform' or 'samples'." + ) + total_samples = sum(getattr(update.report, "num_samples", 0) for update in updates) - if total_samples <= 0: - total_samples = len(payloads) + use_uniform_average = weighting_name == "uniform" or total_samples <= 0 aggregated = torch.zeros_like(first_logits, dtype=torch.float32) @@ -44,9 +51,10 @@ def aggregate_teacher_logits( "FedDF client logits must share the same proxy-set shape." ) - weight = getattr(update.report, "num_samples", 0) / total_samples - if total_samples == len(payloads): + if use_uniform_average: weight = 1 / len(payloads) + else: + weight = getattr(update.report, "num_samples", 0) / total_samples aggregated += logits.detach().float() * weight @@ -62,6 +70,9 @@ def distill_weights( distillation_epochs: int, distillation_batch_size: int, distillation_learning_rate: float, + distillation_optimizer_name: str, + use_cosine_annealing: bool, + shuffle_batches: bool, ) -> OrderedDict[str, torch.Tensor]: """Distill the server model on proxy inputs using ensemble logits.""" if len(proxy_dataset) != len(teacher_logits): @@ -84,17 +95,36 @@ def distill_weights( dataloader = DataLoader( distillation_dataset, batch_size=distillation_batch_size, - shuffle=False, + shuffle=shuffle_batches, ) was_training = model.training model.to(device) model.train() - optimizer = torch.optim.SGD( - model.parameters(), - lr=distillation_learning_rate, - ) + optimizer_name = distillation_optimizer_name.strip().lower() + if optimizer_name == "adam": + optimizer = torch.optim.Adam( + model.parameters(), + lr=distillation_learning_rate, + ) + elif optimizer_name == "sgd": + optimizer = torch.optim.SGD( + model.parameters(), + lr=distillation_learning_rate, + ) + else: + raise ValueError( + "FedDF distillation optimizer must be either 'adam' or 'sgd'." + ) + + total_steps = max(distillation_epochs * len(dataloader), 1) + scheduler = None + if use_cosine_annealing: + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=total_steps, + ) for _ in range(distillation_epochs): for batch_inputs, batch_logits in dataloader: @@ -112,6 +142,8 @@ def distill_weights( ) loss.backward() optimizer.step() + if scheduler is not None: + scheduler.step() if not was_training: model.eval() diff --git a/examples/server_aggregation/feddf/feddf_server_strategy.py b/examples/server_aggregation/feddf/feddf_server_strategy.py index 73b1d6f58..044bcd7d3 100644 --- a/examples/server_aggregation/feddf/feddf_server_strategy.py +++ b/examples/server_aggregation/feddf/feddf_server_strategy.py @@ -23,6 +23,10 @@ def __init__( distillation_epochs: int | None = None, distillation_batch_size: int | None = None, distillation_learning_rate: float | None = None, + teacher_weighting: str | None = None, + distillation_optimizer_name: str | None = None, + use_cosine_annealing: bool | None = None, + shuffle_batches: bool | None = None, ) -> None: super().__init__() self.proxy_set_size = proxy_set_size @@ -31,6 +35,10 @@ def __init__( self.distillation_epochs = distillation_epochs self.distillation_batch_size = distillation_batch_size self.distillation_learning_rate = distillation_learning_rate + self.teacher_weighting = teacher_weighting + self.distillation_optimizer_name = distillation_optimizer_name + self.use_cosine_annealing = use_cosine_annealing + self.shuffle_batches = shuffle_batches async def aggregate_deltas(self, updates, deltas_received, context: ServerContext): """FedDF does not aggregate parameter deltas.""" @@ -97,18 +105,36 @@ async def aggregate_weights( raise RuntimeError("FedDF requires an algorithm instance in context.") proxy_dataset = self._resolve_proxy_dataset(context) - teacher_logits = algorithm.aggregate_teacher_logits(updates, weights_received) + teacher_weighting = resolve_algorithm_value( + "teacher_weighting", self.teacher_weighting, "uniform" + ) + teacher_logits = algorithm.aggregate_teacher_logits( + updates, + weights_received, + weighting=teacher_weighting, + ) - temperature = resolve_algorithm_value("temperature", self.temperature, 2.0) + temperature = resolve_algorithm_value("temperature", self.temperature, 1.0) distillation_epochs = resolve_algorithm_value( "distillation_epochs", self.distillation_epochs, 5 ) distillation_batch_size = resolve_algorithm_value( - "distillation_batch_size", self.distillation_batch_size, 64 + "distillation_batch_size", self.distillation_batch_size, 128 ) distillation_learning_rate = resolve_algorithm_value( "learning_rate", self.distillation_learning_rate, 0.001 ) + distillation_optimizer_name = resolve_algorithm_value( + "distillation_optimizer_name", + self.distillation_optimizer_name, + "adam", + ) + use_cosine_annealing = resolve_algorithm_value( + "use_cosine_annealing", self.use_cosine_annealing, True + ) + shuffle_batches = resolve_algorithm_value( + "shuffle_batches", self.shuffle_batches, True + ) return algorithm.distill_weights( baseline_weights, @@ -118,4 +144,7 @@ async def aggregate_weights( distillation_epochs=distillation_epochs, distillation_batch_size=distillation_batch_size, distillation_learning_rate=distillation_learning_rate, + distillation_optimizer_name=distillation_optimizer_name, + use_cosine_annealing=use_cosine_annealing, + shuffle_batches=shuffle_batches, ) diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py index c763585d6..c30cc9b3f 100644 --- a/tests/servers/test_feddf_server_strategy.py +++ b/tests/servers/test_feddf_server_strategy.py @@ -198,3 +198,21 @@ def test_feddf_server_process_reports_distills_global_model(temp_config): reduction="batchmean", ) assert distilled_loss < baseline_loss + + +def test_feddf_teacher_logits_average_uniformly_by_default(temp_config): + """FedDF should use uniform AVGLOGITS unless configured otherwise.""" + updates = [ + SimpleNamespace(report=SimpleNamespace(num_samples=1)), + SimpleNamespace(report=SimpleNamespace(num_samples=99)), + ] + teacher_logits_a = torch.tensor([[10.0, -10.0]]) + teacher_logits_b = torch.tensor([[-6.0, 6.0]]) + + aggregated = feddf_algorithm.Algorithm.aggregate_teacher_logits( + updates, + [{"logits": teacher_logits_a}, {"logits": teacher_logits_b}], + ) + + expected = (teacher_logits_a + teacher_logits_b) / 2 + assert torch.allclose(aggregated, expected) From 23c782db89c299c546f6b5e101c069b87e9e9fd2 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 15 Apr 2026 12:43:01 -0400 Subject: [PATCH 5/7] Clarified MOON documentation taxonomy. Summary: - moved the detailed MOON walkthrough under customized client training loops - left the server-aggregation page with a pointer that explains MOON keeps FedAvg aggregation unchanged - aligned the example docs with the manuscript taxonomy used for the TKDE revision Validation: - documentation-only change; no runtime behavior changed --- .../1. Server Aggregation Algorithms.md | 36 +++-------------- ...s with Customized Client Training Loops.md | 39 +++++++++++++++++++ 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md index 05167eadb..19cd5a9a3 100644 --- a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md +++ b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md @@ -86,38 +86,12 @@ Key configuration parameters: ### MOON -MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level contrastive regularizer. Each client augments the shared model with a projection head, clones the incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints as negatives. The server still performs sample-weighted averaging but records a short history of global states for downstream analysis or warm restarts. +MOON is included in Plato as a **client-training customization** rather than as a server-aggregation +rule. The server still performs sample-weighted FedAvg; the distinguishing mechanism is the +contrastive local objective together with the historical-model buffer maintained by each client. -```bash -cd examples/server_aggregation/moon/ -uv run moon.py -c moon_MNIST_lenet5.toml -``` - -Key configuration parameters: - -- `algorithm.mu`: Weight assigned to the contrastive term (default: 5.0). -- `algorithm.temperature`: Softmax temperature applied to cosine similarities (default: 0.5). -- `algorithm.history_size`: Number of historical local models cached per client as negatives (default: 2). -- `trainer.model_name`: Name used for checkpointing the projection-ready backbone (default: `moon_lenet5`). - -**Reference:** Qinbin Li, Bingsheng He, Dawn Song. “[Model-Contrastive Federated Learning](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Model-Contrastive_Federated_Learning_CVPR_2021_paper.pdf),” in Proc. CVPR, 2021. - -!!! note "Alignment with the paper" - Here’s how Plato's implementation lines up with Li et al. (CVPR 2021) and the authors’ [reference implementation](https://github.com/Xtra-Computing/MOON): - - - Projection head & representations – `moon_model.py:31-79` implements the LeNet-style backbone plus a two-layer projection head, returning both logits and L2-normalised embeddings. The paper’s Eq. (3) (and typical contrastive- learning practice) calls for that projection step; the public repo’s simple CNN head even hints at it (they keep the projection MLP commented out). So keeping the projection in our model is faithful—and helps the cosine similarities stay well behaved. - - - Local training objective – `moon_trainer.py:26-152` combines the supervised cross-entropy with the temperature-scaled contrastive loss exactly like Eq. (1): positives come from the frozen global model, negatives from the stored local-history models, using the same \\(\\mu\\) and \\(\\tau\\) hyper-parameters exposed in the config (`moon_MNIST_lenet5.toml:41-45`). This mirrors `train_net_fedcon` in the reference implementation, which also weights the contrastive term by \\(\\mu\\) and uses CrossEntropy on logits built from cosine similarities. - - - Historical model buffer – the client keeps a FIFO queue of past local checkpoints (`moon_client.py:21-64`), equivalent to `model_buffer_size` in the paper and the author's reference implementation; that buffer is fed into the trainer through the strategy context so MOON always has negatives available. - - - Server aggregation – the server still performs sample-weighted FedAvg (`moon_server.py:12-35`, `moon_server_strategy.py:19-63`), matching the MOON design which leaves the aggregation rule unchanged. The extra global-history deque is bookkeeping-only. - - - Shared architecture – `moon.py:8-15` now instantiates `MoonModel` once and passes it into both the client and server `(model=model)`. That guarantees the projection-enabled architecture is shared exactly, as required for the contrastive comparisons. - - The only intentional deviation is that we L2-normalise the projection outputs before computing cosine similarities - (`moon_model.py:76-79`), which the paper assumes implicitly and improves stability. Aside from that, the workflow, hyper- - parameters, and loss all line up with the CVPR paper and the publicly released PyTorch reference. +See **5. Algorithms with Customized Client Training Loops** for the runnable example, key +configuration parameters, and the implementation alignment notes for MOON. --- diff --git a/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md b/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md index 179b09e36..94ab100e1 100644 --- a/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md +++ b/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md @@ -58,6 +58,45 @@ uv run feddyn/feddyn.py -c feddyn/feddyn_MNIST_lenet5.toml --- +### MOON + +MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level +contrastive regularizer. Each client augments the shared model with a projection head, clones the +incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints +as negatives. The server still performs sample-weighted averaging but records a short history of +global states for downstream analysis or warm restarts. + +```bash +cd examples/server_aggregation/moon/ +uv run moon.py -c moon_MNIST_lenet5.toml +``` + +Key configuration parameters: + +- `algorithm.mu`: Weight assigned to the contrastive term (default: 5.0). +- `algorithm.temperature`: Softmax temperature applied to cosine similarities (default: 0.5). +- `algorithm.history_size`: Number of historical local models cached per client as negatives (default: 2). +- `trainer.model_name`: Name used for checkpointing the projection-ready backbone (default: `moon_lenet5`). + +**Reference:** Qinbin Li, Bingsheng He, Dawn Song. “[Model-Contrastive Federated Learning](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Model-Contrastive_Federated_Learning_CVPR_2021_paper.pdf),” in Proc. CVPR, 2021. + +!!! note "Alignment with the paper" + Here’s how Plato's implementation lines up with Li et al. (CVPR 2021) and the authors’ [reference implementation](https://github.com/Xtra-Computing/MOON): + + - Projection head & representations – `moon_model.py:31-79` implements the LeNet-style backbone plus a two-layer projection head, returning both logits and L2-normalised embeddings. The paper’s Eq. (3) (and typical contrastive-learning practice) calls for that projection step; the public repo’s simple CNN head even hints at it (they keep the projection MLP commented out). So keeping the projection in our model is faithful and helps the cosine similarities stay well behaved. + + - Local training objective – `moon_trainer.py:26-152` combines the supervised cross-entropy with the temperature-scaled contrastive loss exactly like Eq. (1): positives come from the frozen global model, negatives from the stored local-history models, using the same \\(\\mu\\) and \\(\\tau\\) hyper-parameters exposed in the config (`moon_MNIST_lenet5.toml:41-45`). This mirrors `train_net_fedcon` in the reference implementation, which also weights the contrastive term by \\(\\mu\\) and uses CrossEntropy on logits built from cosine similarities. + + - Historical model buffer – the client keeps a FIFO queue of past local checkpoints (`moon_client.py:21-64`), equivalent to `model_buffer_size` in the paper and the author's reference implementation; that buffer is fed into the trainer through the strategy context so MOON always has negatives available. + + - Server aggregation – the server still performs sample-weighted FedAvg (`moon_server.py:12-35`, `moon_server_strategy.py:19-63`), matching the MOON design which leaves the aggregation rule unchanged. The extra global-history deque is bookkeeping-only. + + - Shared architecture – `moon.py:8-15` now instantiates `MoonModel` once and passes it into both the client and server `(model=model)`. That guarantees the projection-enabled architecture is shared exactly, as required for the contrastive comparisons. + + The only intentional deviation is that we L2-normalise the projection outputs before computing cosine similarities (`moon_model.py:76-79`), which the paper assumes implicitly and improves stability. Aside from that, the workflow, hyper-parameters, and loss all line up with the CVPR paper and the publicly released PyTorch reference. + +--- + ### FedMoS FedMoS is a communication-efficient FL framework with coupled double momentum-based update and adaptive client selection, to jointly mitigate the intrinsic variance. From 32efefdb0b826888ba8328ce63b898af88e660b1 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 15 Apr 2026 12:44:03 -0400 Subject: [PATCH 6/7] Clarified MOON taxonomy heading. Summary: - marked the MOON pointer in the server-aggregation docs as a client-training customization - avoided leaving MOON visually classified as a server aggregation rule in the packaged docs Validation: - documentation-only change; no runtime behavior changed --- .../examples/algorithms/1. Server Aggregation Algorithms.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md index 19cd5a9a3..65a03b1a4 100644 --- a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md +++ b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md @@ -84,7 +84,7 @@ Key configuration parameters: --- -### MOON +### MOON (client-training customization) MOON is included in Plato as a **client-training customization** rather than as a server-aggregation rule. The server still performs sample-weighted FedAvg; the distinguishing mechanism is the From 3d04d89615d511d765302e922935ce0e5e0f563e Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 15 Apr 2026 13:14:02 -0400 Subject: [PATCH 7/7] Timed FedDF client and server overhead. Include the proxy-logit generation pass in the client-reported training time and record server distillation time on the FedDF server so logged round_time and elapsed_time reflect FedDF-specific work.\n\nAlso cache the stacked proxy inputs on the server and add focused tests for the FedDF timing path. --- .../server_aggregation/feddf/feddf_client.py | 8 ++++++ .../server_aggregation/feddf/feddf_server.py | 26 ++++++++++++++++++- .../feddf/feddf_server_strategy.py | 8 +++++- tests/clients/test_feddf_strategy.py | 9 ++++++- tests/servers/test_feddf_server_strategy.py | 8 ++++++ 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/examples/server_aggregation/feddf/feddf_client.py b/examples/server_aggregation/feddf/feddf_client.py index b0fcf6a59..23103e28f 100644 --- a/examples/server_aggregation/feddf/feddf_client.py +++ b/examples/server_aggregation/feddf/feddf_client.py @@ -2,6 +2,8 @@ from __future__ import annotations +import time + from feddf_utils import ( collect_proxy_logits, resolve_algorithm_value, @@ -50,12 +52,18 @@ async def train(self, context): proxy_batch_size = resolve_algorithm_value( "proxy_batch_size", self.proxy_batch_size, 128 ) + tic = time.perf_counter() logits = collect_proxy_logits( trainer.model, TensorDataset(proxy_inputs), batch_size=proxy_batch_size, device=getattr(trainer, "device", "cpu"), ) + proxy_logits_time = time.perf_counter() - tic + + report.training_time = getattr(report, "training_time", 0.0) + proxy_logits_time + report.feddf_proxy_logits_time = proxy_logits_time + context.state["feddf_proxy_logits_time"] = proxy_logits_time report.payload_type = "feddf_logits" report.proxy_size = len(proxy_inputs) diff --git a/examples/server_aggregation/feddf/feddf_server.py b/examples/server_aggregation/feddf/feddf_server.py index 192327b8d..3c3f05b60 100644 --- a/examples/server_aggregation/feddf/feddf_server.py +++ b/examples/server_aggregation/feddf/feddf_server.py @@ -2,6 +2,8 @@ from __future__ import annotations +import time + from feddf_algorithm import Algorithm as FedDFAlgorithm from feddf_server_strategy import FedDFAggregationStrategy from feddf_utils import stack_proxy_inputs @@ -36,14 +38,36 @@ def __init__( aggregation_strategy=aggregation_strategy, client_selection_strategy=client_selection_strategy, ) + self.feddf_server_distillation_time = 0.0 def customize_server_payload(self, payload): """Send weights together with the shared proxy inputs for FedDF.""" proxy_dataset = self.aggregation_strategy._resolve_proxy_dataset(self.context) - proxy_inputs = stack_proxy_inputs(proxy_dataset) + proxy_inputs = self.context.state.get("feddf_proxy_inputs") + if proxy_inputs is None: + proxy_inputs = stack_proxy_inputs(proxy_dataset) self.context.state["feddf_proxy_inputs"] = proxy_inputs return { "weights": payload, "proxy_inputs": proxy_inputs, } + + def clients_processed(self) -> None: + """Add server distillation time to the simulated round timing.""" + self.feddf_server_distillation_time = float( + self.context.state.pop("feddf_server_distillation_time", 0.0) + ) + + if self.simulate_wall_time: + self.wall_time += self.feddf_server_distillation_time + else: + self.wall_time = time.time() + + def get_logged_items(self) -> dict: + """Include FedDF server distillation in logged round metrics.""" + logged = super().get_logged_items() + logged["processing_time"] += self.feddf_server_distillation_time + logged["round_time"] += self.feddf_server_distillation_time + logged["feddf_server_distillation_time"] = self.feddf_server_distillation_time + return logged diff --git a/examples/server_aggregation/feddf/feddf_server_strategy.py b/examples/server_aggregation/feddf/feddf_server_strategy.py index 044bcd7d3..977d0f5b1 100644 --- a/examples/server_aggregation/feddf/feddf_server_strategy.py +++ b/examples/server_aggregation/feddf/feddf_server_strategy.py @@ -2,6 +2,7 @@ from __future__ import annotations +import time from typing import Mapping from feddf_utils import resolve_algorithm_value, select_proxy_subset @@ -136,7 +137,9 @@ async def aggregate_weights( "shuffle_batches", self.shuffle_batches, True ) - return algorithm.distill_weights( + context.state["feddf_server_distillation_time"] = 0.0 + tic = time.perf_counter() + updated_weights = algorithm.distill_weights( baseline_weights, teacher_logits, proxy_dataset, @@ -148,3 +151,6 @@ async def aggregate_weights( use_cosine_annealing=use_cosine_annealing, shuffle_batches=shuffle_batches, ) + context.state["feddf_server_distillation_time"] = time.perf_counter() - tic + + return updated_weights diff --git a/tests/clients/test_feddf_strategy.py b/tests/clients/test_feddf_strategy.py index 66f9bd1b4..6676eda37 100644 --- a/tests/clients/test_feddf_strategy.py +++ b/tests/clients/test_feddf_strategy.py @@ -62,11 +62,17 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): feddf_client.DefaultTrainingStrategy, "train", new=async_mock, - ) as mock_train: + ) as mock_train, patch.object( + feddf_client.time, + "perf_counter", + side_effect=[10.0, 10.25], + ): report, payload = asyncio.run(strategy.train(context)) mock_train.assert_awaited_once() assert report is mock_report + assert getattr(report, "training_time") == 0.25 + assert getattr(report, "feddf_proxy_logits_time") == 0.25 assert getattr(report, "payload_type") == "feddf_logits" assert getattr(report, "proxy_size") == 3 assert "logits" in payload @@ -74,3 +80,4 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): assert tuple(payload["logits"].shape) == (3, 2) assert loaded_weights == [inbound_payload["weights"]] assert torch.equal(context.state["feddf_proxy_inputs"], proxy_inputs) + assert context.state["feddf_proxy_logits_time"] == 0.25 diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py index c30cc9b3f..edf382b1d 100644 --- a/tests/servers/test_feddf_server_strategy.py +++ b/tests/servers/test_feddf_server_strategy.py @@ -198,6 +198,14 @@ def test_feddf_server_process_reports_distills_global_model(temp_config): reduction="batchmean", ) assert distilled_loss < baseline_loss + assert server.feddf_server_distillation_time > 0 + + logged_items = server.get_logged_items() + assert logged_items["feddf_server_distillation_time"] == ( + server.feddf_server_distillation_time + ) + assert logged_items["round_time"] == server.feddf_server_distillation_time + assert logged_items["elapsed_time"] >= server.feddf_server_distillation_time def test_feddf_teacher_logits_average_uniformly_by_default(temp_config):