diff --git a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md index ebf721219..65a03b1a4 100644 --- a/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md +++ b/docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md @@ -56,40 +56,42 @@ Key configuration parameters: --- -### MOON +### FedDF -MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level contrastive regularizer. Each client augments the shared model with a projection head, clones the incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints as negatives. The server still performs sample-weighted averaging but records a short history of global states for downstream analysis or warm restarts. +FedDF (Federated Distillation and Fusion) replaces direct parameter averaging with server-side distillation on a proxy set. The server selects a deterministic unlabeled proxy subset, ships those proxy inputs alongside the current global weights, and each client returns teacher logits on that shared payload instead of ordinary weight deltas. The server then aggregates those logits and distills their ensemble into the next global model using temperature-scaled soft targets. ```bash -cd examples/server_aggregation/moon/ -uv run moon.py -c moon_MNIST_lenet5.toml +cd examples/server_aggregation/feddf/ +uv run feddf.py -c feddf_MNIST_lenet5.toml ``` Key configuration parameters: -- `algorithm.mu`: Weight assigned to the contrastive term (default: 5.0). -- `algorithm.temperature`: Softmax temperature applied to cosine similarities (default: 0.5). -- `algorithm.history_size`: Number of historical local models cached per client as negatives (default: 2). -- `trainer.model_name`: Name used for checkpointing the projection-ready backbone (default: `moon_lenet5`). 
+- `algorithm.proxy_set_size`: Number of unlabeled proxy samples used on the server for distillation. +- `algorithm.proxy_batch_size`: Batch size for iterating through the proxy set. +- `algorithm.proxy_seed`: Seed used to select the deterministic proxy subset shared by clients and server. +- `algorithm.temperature`: Softmax temperature used to smooth teacher logits before distillation. +- `algorithm.distillation_epochs`: Number of server-side distillation passes per round. +- `algorithm.distillation_batch_size`: Batch size for the distillation optimizer. +- `algorithm.learning_rate`: Learning rate for the server-side distillation optimizer. -**Reference:** Qinbin Li, Bingsheng He, Dawn Song. “[Model-Contrastive Federated Learning](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Model-Contrastive_Federated_Learning_CVPR_2021_paper.pdf),” in Proc. CVPR, 2021. +**Reference:** Tao Lin, Lingjing Kong, Sebastian U. Stich, Martin Jaggi. “[Ensemble Distillation for Robust Model Fusion in Federated Learning](https://arxiv.org/abs/2006.07242),” in Proc. NeurIPS, 2020. !!! note "Alignment with the paper" - Here’s how Plato's implementation lines up with Li et al. (CVPR 2021) and the authors’ [reference implementation](https://github.com/Xtra-Computing/MOON): + The module split follows the FedDF workflow directly: `feddf.py` stays as a thin launcher, `feddf_server.py` packages the current global weights together with the shared proxy inputs, `feddf_client.py` performs the standard local update and then emits logits on that server-supplied proxy payload, `feddf_server_strategy.py` resolves the deterministic proxy subset and routes the logits payloads through the server's `aggregate_weights` path, and `feddf_algorithm.py` encapsulates the weighted-logit ensemble plus the temperature-scaled KL distillation step. 
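The ensemble-and-distill step that `feddf_algorithm.py` implements can be sketched in a few lines of plain PyTorch. This is a minimal illustration with made-up shapes and helper names, not the Plato API: uniform logit averaging over clients, followed by the temperature-scaled KL loss used as the distillation objective.

```python
import torch
import torch.nn.functional as F


def ensemble_logits(client_logits: list[torch.Tensor]) -> torch.Tensor:
    """Uniformly average per-client logits of shape (proxy_size, num_classes)."""
    return torch.stack(client_logits).mean(dim=0)


def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 2.0,
) -> torch.Tensor:
    """KL divergence between temperature-softened student and teacher distributions."""
    teacher_probs = torch.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    # Scaling by T^2 keeps gradient magnitudes comparable across temperatures.
    return (
        F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        * temperature**2
    )


# Two hypothetical clients evaluated on a 4-sample, 3-class proxy set.
teacher = ensemble_logits([torch.randn(4, 3), torch.randn(4, 3)])
loss = distillation_loss(torch.randn(4, 3), teacher)
```

A quick sanity check: passing the teacher's own logits as the student drives the loss to zero, since the two softened distributions coincide.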
- - Projection head & representations – `moon_model.py:31-79` implements the LeNet-style backbone plus a two-layer projection head, returning both logits and L2-normalised embeddings. The paper’s Eq. (3) (and typical contrastive- learning practice) calls for that projection step; the public repo’s simple CNN head even hints at it (they keep the projection MLP commented out). So keeping the projection in our model is faithful—and helps the cosine similarities stay well behaved. + The configuration surface above mirrors the paper’s core knobs. `proxy_set_size`, `proxy_batch_size`, and `proxy_seed` control the deterministic unlabeled proxy data used for ensemble distillation, `temperature` shapes the softened teacher distribution, and `distillation_epochs`, `distillation_batch_size`, and `learning_rate` control the server-side student optimization that replaces direct averaging. - - Local training objective – `moon_trainer.py:26-152` combines the supervised cross-entropy with the temperature-scaled contrastive loss exactly like Eq. (1): positives come from the frozen global model, negatives from the stored local-history models, using the same \\(\\mu\\) and \\(\\tau\\) hyper-parameters exposed in the config (`moon_MNIST_lenet5.toml:41-45`). This mirrors `train_net_fedcon` in the reference implementation, which also weights the contrastive term by \\(\\mu\\) and uses CrossEntropy on logits built from cosine similarities. - - - Historical model buffer – the client keeps a FIFO queue of past local checkpoints (`moon_client.py:21-64`), equivalent to `model_buffer_size` in the paper and the author's reference implementation; that buffer is fed into the trainer through the strategy context so MOON always has negatives available. +--- - - Server aggregation – the server still performs sample-weighted FedAvg (`moon_server.py:12-35`, `moon_server_strategy.py:19-63`), matching the MOON design which leaves the aggregation rule unchanged. 
The extra global-history deque is bookkeeping-only. +### MOON (client-training customization) - - Shared architecture – `moon.py:8-15` now instantiates `MoonModel` once and passes it into both the client and server `(model=model)`. That guarantees the projection-enabled architecture is shared exactly, as required for the contrastive comparisons. +MOON is included in Plato as a **client-training customization** rather than as a server-aggregation +rule. The server still performs sample-weighted FedAvg; the distinguishing mechanism is the +contrastive local objective together with the historical-model buffer maintained by each client. - The only intentional deviation is that we L2-normalise the projection outputs before computing cosine similarities - (`moon_model.py:76-79`), which the paper assumes implicitly and improves stability. Aside from that, the workflow, hyper- - parameters, and loss all line up with the CVPR paper and the publicly released PyTorch reference. +See **5. Algorithms with Customized Client Training Loops** for the runnable example, key +configuration parameters, and the implementation alignment notes for MOON. --- diff --git a/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md b/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md index 179b09e36..94ab100e1 100644 --- a/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md +++ b/docs/docs/examples/algorithms/5. Algorithms with Customized Client Training Loops.md @@ -58,6 +58,45 @@ uv run feddyn/feddyn.py -c feddyn/feddyn_MNIST_lenet5.toml --- +### MOON + +MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level +contrastive regularizer. Each client augments the shared model with a projection head, clones the +incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints +as negatives. 
The server still performs sample-weighted averaging but records a short history of +global states for downstream analysis or warm restarts. + +```bash +cd examples/server_aggregation/moon/ +uv run moon.py -c moon_MNIST_lenet5.toml +``` + +Key configuration parameters: + +- `algorithm.mu`: Weight assigned to the contrastive term (default: 5.0). +- `algorithm.temperature`: Softmax temperature applied to cosine similarities (default: 0.5). +- `algorithm.history_size`: Number of historical local models cached per client as negatives (default: 2). +- `trainer.model_name`: Name used for checkpointing the projection-ready backbone (default: `moon_lenet5`). + +**Reference:** Qinbin Li, Bingsheng He, Dawn Song. “[Model-Contrastive Federated Learning](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Model-Contrastive_Federated_Learning_CVPR_2021_paper.pdf),” in Proc. CVPR, 2021. + +!!! note "Alignment with the paper" + Here’s how Plato's implementation lines up with Li et al. (CVPR 2021) and the authors’ [reference implementation](https://github.com/Xtra-Computing/MOON): + + - Projection head & representations – `moon_model.py:31-79` implements the LeNet-style backbone plus a two-layer projection head, returning both logits and L2-normalised embeddings. The paper’s Eq. (3) (and typical contrastive-learning practice) calls for that projection step; the public repo’s simple CNN head even hints at it (they keep the projection MLP commented out). So keeping the projection in our model is faithful and helps the cosine similarities stay well behaved. + + - Local training objective – `moon_trainer.py:26-152` combines the supervised cross-entropy with the temperature-scaled contrastive loss exactly like Eq. (1): positives come from the frozen global model, negatives from the stored local-history models, using the same \\(\\mu\\) and \\(\\tau\\) hyper-parameters exposed in the config (`moon_MNIST_lenet5.toml:41-45`). 
This mirrors `train_net_fedcon` in the reference implementation, which also weights the contrastive term by \\(\\mu\\) and uses CrossEntropy on logits built from cosine similarities. + + - Historical model buffer – the client keeps a FIFO queue of past local checkpoints (`moon_client.py:21-64`), equivalent to `model_buffer_size` in the paper and the authors’ reference implementation; that buffer is fed into the trainer through the strategy context so MOON always has negatives available. + + - Server aggregation – the server still performs sample-weighted FedAvg (`moon_server.py:12-35`, `moon_server_strategy.py:19-63`), matching the MOON design which leaves the aggregation rule unchanged. The extra global-history deque is bookkeeping-only. + + - Shared architecture – `moon.py:8-15` now instantiates `MoonModel` once and passes it into both the client and server `(model=model)`. That guarantees the projection-enabled architecture is shared exactly, as required for the contrastive comparisons. + + The only intentional deviation is that we L2-normalise the projection outputs before computing cosine similarities (`moon_model.py:76-79`), which the paper assumes implicitly and improves stability. Aside from that, the workflow, hyper-parameters, and loss all line up with the CVPR paper and the publicly released PyTorch reference. + +--- + ### FedMoS FedMoS is a communication-efficient FL framework with coupled double momentum-based update and adaptive client selection, to jointly mitigate the intrinsic variance. diff --git a/examples/server_aggregation/feddf/feddf.py b/examples/server_aggregation/feddf/feddf.py new file mode 100644 index 000000000..ab6701601 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf.py @@ -0,0 +1,19 @@ +""" +Entry point for running the FedDF server aggregation example. 
+""" + +from __future__ import annotations + +import feddf_client +import feddf_server + + +def main(): + """Launch a Plato training session with the FedDF algorithm.""" + client = feddf_client.create_client() + server = feddf_server.Server() + server.run(client) + + +if __name__ == "__main__": + main() diff --git a/examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml b/examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml new file mode 100644 index 000000000..11b2b6c68 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml @@ -0,0 +1,64 @@ +[clients] +type = "simple" +total_clients = 100 +per_round = 10 +random_seed = 1 +do_test = false +speed_simulation = true +sleep_simulation = true +avg_training_time = 20 + +[clients.simulation_distribution] +distribution = "normal" +mean = 10 +sd = 3 + +[server] +address = "127.0.0.1" +port = 8000 +synchronous = false +simulate_wall_time = true +minimum_clients_aggregated = 6 +staleness_bound = 10 +random_seed = 1 + +[data] +datasource = "Torchvision" +dataset_name = "MNIST" +download = true +data_path = "data" +partition_size = 200 +sampler = "noniid" +concentration = 0.5 +random_seed = 1 + +[trainer] +rounds = 20 +type = "basic" +max_concurrency = 4 +target_accuracy = 1.0 +epochs = 5 +batch_size = 64 +optimizer = "SGD" +model_name = "lenet5" + +[algorithm] +type = "fedavg" +proxy_set_size = 2048 +proxy_batch_size = 128 +proxy_seed = 1 +temperature = 2.0 +distillation_epochs = 5 +distillation_batch_size = 64 +learning_rate = 0.001 + +[parameters] + +[parameters.optimizer] +lr = 0.01 +momentum = 0.9 +weight_decay = 0.0 + +[results] +result_path = "results/MNIST_lenet5/FedDF" +types = "round, accuracy, elapsed_time, comm_time, round_time" diff --git a/examples/server_aggregation/feddf/feddf_algorithm.py b/examples/server_aggregation/feddf/feddf_algorithm.py new file mode 100644 index 000000000..57164529e --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_algorithm.py @@ -0,0 +1,154 @@ 
+"""FedDF-specific helpers for ensemble distillation on the server.""" + +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Mapping, Sequence + +import torch +import torch.nn.functional as F +from feddf_utils import extract_batch_inputs, unwrap_model_outputs +from torch.utils.data import DataLoader, Dataset, TensorDataset + +from plato.algorithms import fedavg + + +class Algorithm(fedavg.Algorithm): + """Algorithm helpers for aggregating logits and distilling the student.""" + + @staticmethod + def aggregate_teacher_logits( + updates, + payloads: Sequence[Mapping[str, torch.Tensor]], + *, + weighting: str = "uniform", + ) -> torch.Tensor: + """Compute the ensembled teacher logits for AVGLOGITS distillation.""" + if not payloads: + raise ValueError("FedDF requires at least one logits payload.") + + first_logits = payloads[0].get("logits") + if not isinstance(first_logits, torch.Tensor): + raise TypeError("FedDF payloads must include a 'logits' tensor.") + + weighting_name = weighting.strip().lower() + if weighting_name not in {"uniform", "samples"}: + raise ValueError( + "FedDF teacher weighting must be either 'uniform' or 'samples'." + ) + + total_samples = sum(getattr(update.report, "num_samples", 0) for update in updates) + use_uniform_average = weighting_name == "uniform" or total_samples <= 0 + + aggregated = torch.zeros_like(first_logits, dtype=torch.float32) + + for update, payload in zip(updates, payloads): + logits = payload.get("logits") + if not isinstance(logits, torch.Tensor): + raise TypeError("FedDF payloads must include a 'logits' tensor.") + if logits.shape != first_logits.shape: + raise ValueError( + "FedDF client logits must share the same proxy-set shape." 
+ ) + + if use_uniform_average: + weight = 1 / len(payloads) + else: + weight = getattr(update.report, "num_samples", 0) / total_samples + + aggregated += logits.detach().float() * weight + + return aggregated + + def distill_weights( + self, + baseline_weights: Mapping[str, torch.Tensor], + teacher_logits: torch.Tensor, + proxy_dataset: Dataset, + *, + temperature: float, + distillation_epochs: int, + distillation_batch_size: int, + distillation_learning_rate: float, + distillation_optimizer_name: str, + use_cosine_annealing: bool, + shuffle_batches: bool, + ) -> OrderedDict[str, torch.Tensor]: + """Distill the server model on proxy inputs using ensemble logits.""" + if len(proxy_dataset) != len(teacher_logits): + raise ValueError( + "FedDF proxy samples and teacher logits must have matching lengths." + ) + + trainer = self.require_trainer() + model = self.require_model() + device = torch.device(getattr(trainer, "device", "cpu")) + + self.load_weights(baseline_weights) + + inputs = [] + for example in proxy_dataset: + inputs.append(extract_batch_inputs(example)) + + proxy_inputs = torch.stack(inputs) + distillation_dataset = TensorDataset(proxy_inputs, teacher_logits.detach().cpu()) + dataloader = DataLoader( + distillation_dataset, + batch_size=distillation_batch_size, + shuffle=shuffle_batches, + ) + + was_training = model.training + model.to(device) + model.train() + + optimizer_name = distillation_optimizer_name.strip().lower() + if optimizer_name == "adam": + optimizer = torch.optim.Adam( + model.parameters(), + lr=distillation_learning_rate, + ) + elif optimizer_name == "sgd": + optimizer = torch.optim.SGD( + model.parameters(), + lr=distillation_learning_rate, + ) + else: + raise ValueError( + "FedDF distillation optimizer must be either 'adam' or 'sgd'." 
+ ) + + total_steps = max(distillation_epochs * len(dataloader), 1) + scheduler = None + if use_cosine_annealing: + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=total_steps, + ) + + for _ in range(distillation_epochs): + for batch_inputs, batch_logits in dataloader: + batch_inputs = batch_inputs.to(device) + batch_logits = batch_logits.to(device) + teacher_probs = torch.softmax(batch_logits / temperature, dim=1) + + optimizer.zero_grad() + student_logits = unwrap_model_outputs(model(batch_inputs)) + student_log_probs = F.log_softmax(student_logits / temperature, dim=1) + loss = ( + F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") + * temperature + * temperature + ) + loss.backward() + optimizer.step() + if scheduler is not None: + scheduler.step() + + if not was_training: + model.eval() + + return OrderedDict( + (name, tensor.detach().cpu().clone()) + for name, tensor in model.state_dict().items() + ) diff --git a/examples/server_aggregation/feddf/feddf_client.py b/examples/server_aggregation/feddf/feddf_client.py new file mode 100644 index 000000000..23103e28f --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_client.py @@ -0,0 +1,107 @@ +"""Client implementation for the FedDF server aggregation example.""" + +from __future__ import annotations + +import time + +from feddf_utils import ( + collect_proxy_logits, + resolve_algorithm_value, +) +from torch.utils.data import TensorDataset + +from plato.clients import simple +from plato.clients.strategies.defaults import DefaultTrainingStrategy + + +class FedDFTrainingStrategy(DefaultTrainingStrategy): + """Train locally, then emit teacher logits on a shared proxy set.""" + + def __init__( + self, + *, + proxy_batch_size: int | None = None, + ) -> None: + super().__init__() + self.proxy_batch_size = proxy_batch_size + + def load_payload(self, context, server_payload) -> None: + """Load model weights and cache the shared proxy inputs from the server.""" + if not 
isinstance(server_payload, dict): + raise TypeError("FedDF expects a dictionary payload from the server.") + + if "weights" not in server_payload or "proxy_inputs" not in server_payload: + raise KeyError( + "FedDF server payload must include 'weights' and 'proxy_inputs'." + ) + + context.state["feddf_proxy_inputs"] = server_payload["proxy_inputs"] + super().load_payload(context, server_payload["weights"]) + + async def train(self, context): + report, _ = await super().train(context) + + trainer = getattr(context, "trainer", None) + if trainer is None or getattr(trainer, "model", None) is None: + raise RuntimeError("FedDF requires a trainer with a model for logits.") + + proxy_inputs = context.state.get("feddf_proxy_inputs") + if proxy_inputs is None: + raise RuntimeError("FedDF requires shared proxy inputs from the server.") + + proxy_batch_size = resolve_algorithm_value( + "proxy_batch_size", self.proxy_batch_size, 128 + ) + tic = time.perf_counter() + logits = collect_proxy_logits( + trainer.model, + TensorDataset(proxy_inputs), + batch_size=proxy_batch_size, + device=getattr(trainer, "device", "cpu"), + ) + proxy_logits_time = time.perf_counter() - tic + + report.training_time = getattr(report, "training_time", 0.0) + proxy_logits_time + report.feddf_proxy_logits_time = proxy_logits_time + context.state["feddf_proxy_logits_time"] = proxy_logits_time + + report.payload_type = "feddf_logits" + report.proxy_size = len(proxy_inputs) + + return report, {"logits": logits} + + +def create_client( + *, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + trainer_callbacks=None, + proxy_batch_size: int | None = None, +): + """Build a client configured to emit FedDF proxy-set logits.""" + client = simple.Client( + model=model, + datasource=datasource, + algorithm=algorithm, + trainer=trainer, + callbacks=callbacks, + trainer_callbacks=trainer_callbacks, + ) + + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, 
+ payload_strategy=client.payload_strategy, + training_strategy=FedDFTrainingStrategy( + proxy_batch_size=proxy_batch_size, + ), + reporting_strategy=client.reporting_strategy, + communication_strategy=client.communication_strategy, + ) + + return client + + +Client = create_client diff --git a/examples/server_aggregation/feddf/feddf_server.py b/examples/server_aggregation/feddf/feddf_server.py new file mode 100644 index 000000000..3c3f05b60 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_server.py @@ -0,0 +1,73 @@ +"""Server wrapper for the FedDF server aggregation example.""" + +from __future__ import annotations + +import time + +from feddf_algorithm import Algorithm as FedDFAlgorithm +from feddf_server_strategy import FedDFAggregationStrategy +from feddf_utils import stack_proxy_inputs + +from plato.servers import fedavg + + +class Server(fedavg.Server): + """A federated learning server using FedDF distillation aggregation.""" + + def __init__( + self, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + aggregation_strategy=None, + client_selection_strategy=None, + ): + if aggregation_strategy is None: + aggregation_strategy = FedDFAggregationStrategy() + + selected_algorithm = algorithm or FedDFAlgorithm + + super().__init__( + model=model, + datasource=datasource, + algorithm=selected_algorithm, + trainer=trainer, + callbacks=callbacks, + aggregation_strategy=aggregation_strategy, + client_selection_strategy=client_selection_strategy, + ) + self.feddf_server_distillation_time = 0.0 + + def customize_server_payload(self, payload): + """Send weights together with the shared proxy inputs for FedDF.""" + proxy_dataset = self.aggregation_strategy._resolve_proxy_dataset(self.context) + proxy_inputs = self.context.state.get("feddf_proxy_inputs") + if proxy_inputs is None: + proxy_inputs = stack_proxy_inputs(proxy_dataset) + self.context.state["feddf_proxy_inputs"] = proxy_inputs + + return { + "weights": payload, + 
"proxy_inputs": proxy_inputs, + } + + def clients_processed(self) -> None: + """Add server distillation time to the simulated round timing.""" + self.feddf_server_distillation_time = float( + self.context.state.pop("feddf_server_distillation_time", 0.0) + ) + + if self.simulate_wall_time: + self.wall_time += self.feddf_server_distillation_time + else: + self.wall_time = time.time() + + def get_logged_items(self) -> dict: + """Include FedDF server distillation in logged round metrics.""" + logged = super().get_logged_items() + logged["processing_time"] += self.feddf_server_distillation_time + logged["round_time"] += self.feddf_server_distillation_time + logged["feddf_server_distillation_time"] = self.feddf_server_distillation_time + return logged diff --git a/examples/server_aggregation/feddf/feddf_server_strategy.py b/examples/server_aggregation/feddf/feddf_server_strategy.py new file mode 100644 index 000000000..977d0f5b1 --- /dev/null +++ b/examples/server_aggregation/feddf/feddf_server_strategy.py @@ -0,0 +1,156 @@ +"""FedDF aggregation strategy using server-side proxy-set distillation.""" + +from __future__ import annotations + +import time +from typing import Mapping + +from feddf_utils import resolve_algorithm_value, select_proxy_subset +from torch.utils.data import Dataset + +from plato.datasources import registry as datasources_registry +from plato.servers.strategies.base import AggregationStrategy, ServerContext + + +class FedDFAggregationStrategy(AggregationStrategy): + """Aggregate client logits and distill them into the global student.""" + + def __init__( + self, + *, + proxy_set_size: int | None = None, + proxy_seed: int | None = None, + temperature: float | None = None, + distillation_epochs: int | None = None, + distillation_batch_size: int | None = None, + distillation_learning_rate: float | None = None, + teacher_weighting: str | None = None, + distillation_optimizer_name: str | None = None, + use_cosine_annealing: bool | None = None, + 
shuffle_batches: bool | None = None, + ) -> None: + super().__init__() + self.proxy_set_size = proxy_set_size + self.proxy_seed = proxy_seed + self.temperature = temperature + self.distillation_epochs = distillation_epochs + self.distillation_batch_size = distillation_batch_size + self.distillation_learning_rate = distillation_learning_rate + self.teacher_weighting = teacher_weighting + self.distillation_optimizer_name = distillation_optimizer_name + self.use_cosine_annealing = use_cosine_annealing + self.shuffle_batches = shuffle_batches + + async def aggregate_deltas(self, updates, deltas_received, context: ServerContext): + """FedDF does not aggregate parameter deltas.""" + raise NotImplementedError("FedDF uses aggregate_weights with logits payloads.") + + @staticmethod + def _proxy_source_dataset(datasource): + """Prefer an unlabeled proxy split, falling back to the test split.""" + if hasattr(datasource, "get_unlabeled_set"): + unlabeled_set = datasource.get_unlabeled_set() + if unlabeled_set is not None: + return unlabeled_set + + test_set = datasource.get_test_set() + if test_set is None: + raise RuntimeError( + "FedDF requires either an unlabeled proxy split or a test split." 
+ ) + return test_set + + def _resolve_proxy_dataset(self, context: ServerContext) -> Dataset: + """Construct or reuse the deterministic proxy subset.""" + cached = context.state.get("feddf_proxy_dataset") + if cached is not None: + return cached + + server = getattr(context, "server", None) + if server is None: + raise RuntimeError("FedDF requires the server in strategy context.") + + datasource = getattr(server, "datasource", None) + if datasource is None: + custom_datasource = getattr(server, "custom_datasource", None) + if custom_datasource is not None: + datasource = custom_datasource() + else: + datasource = datasources_registry.get(client_id=0) + server.datasource = datasource + + proxy_set_size = resolve_algorithm_value( + "proxy_set_size", self.proxy_set_size, 512 + ) + proxy_seed = resolve_algorithm_value("proxy_seed", self.proxy_seed, 1) + proxy_dataset, proxy_indices = select_proxy_subset( + self._proxy_source_dataset(datasource), + size=proxy_set_size, + seed=proxy_seed, + ) + + context.state["feddf_proxy_dataset"] = proxy_dataset + context.state["feddf_proxy_indices"] = proxy_indices + return proxy_dataset + + async def aggregate_weights( + self, + updates, + baseline_weights: Mapping, + weights_received, + context: ServerContext, + ): + """Distill the global student from client logits on the proxy set.""" + algorithm = getattr(context, "algorithm", None) + if algorithm is None: + raise RuntimeError("FedDF requires an algorithm instance in context.") + + proxy_dataset = self._resolve_proxy_dataset(context) + teacher_weighting = resolve_algorithm_value( + "teacher_weighting", self.teacher_weighting, "uniform" + ) + teacher_logits = algorithm.aggregate_teacher_logits( + updates, + weights_received, + weighting=teacher_weighting, + ) + + temperature = resolve_algorithm_value("temperature", self.temperature, 1.0) + distillation_epochs = resolve_algorithm_value( + "distillation_epochs", self.distillation_epochs, 5 + ) + distillation_batch_size = 
resolve_algorithm_value(
+            "distillation_batch_size", self.distillation_batch_size, 128
+        )
+        distillation_learning_rate = resolve_algorithm_value(
+            "learning_rate", self.distillation_learning_rate, 0.001
+        )
+        distillation_optimizer_name = resolve_algorithm_value(
+            "distillation_optimizer_name",
+            self.distillation_optimizer_name,
+            "adam",
+        )
+        use_cosine_annealing = resolve_algorithm_value(
+            "use_cosine_annealing", self.use_cosine_annealing, True
+        )
+        shuffle_batches = resolve_algorithm_value(
+            "shuffle_batches", self.shuffle_batches, True
+        )
+
+        context.state["feddf_server_distillation_time"] = 0.0
+        tic = time.perf_counter()
+        updated_weights = algorithm.distill_weights(
+            baseline_weights,
+            teacher_logits,
+            proxy_dataset,
+            temperature=temperature,
+            distillation_epochs=distillation_epochs,
+            distillation_batch_size=distillation_batch_size,
+            distillation_learning_rate=distillation_learning_rate,
+            distillation_optimizer_name=distillation_optimizer_name,
+            use_cosine_annealing=use_cosine_annealing,
+            shuffle_batches=shuffle_batches,
+        )
+        context.state["feddf_server_distillation_time"] = time.perf_counter() - tic
+
+        return updated_weights
diff --git a/examples/server_aggregation/feddf/feddf_utils.py b/examples/server_aggregation/feddf/feddf_utils.py
new file mode 100644
index 000000000..bd3be1c4c
--- /dev/null
+++ b/examples/server_aggregation/feddf/feddf_utils.py
@@ -0,0 +1,101 @@
+"""Shared proxy-set helpers for the FedDF server aggregation example."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from torch.utils.data import DataLoader, Dataset, Subset
+
+from plato.config import Config
+
+
+def resolve_algorithm_value(name: str, explicit_value: Any, default: Any) -> Any:
+    """Resolve an example parameter from the constructor or config file."""
+    if explicit_value is not None:
+        return explicit_value
+
+    algorithm_cfg = getattr(Config(), "algorithm", None)
+    if algorithm_cfg is None:
+        return default
+
+    return getattr(algorithm_cfg, name, default)
+
+
+def select_proxy_subset(
+    dataset: Dataset,
+    *,
+    size: int,
+    seed: int,
+) -> tuple[Subset, list[int]]:
+    """Build a deterministic subset of the shared proxy dataset."""
+    total_examples = len(dataset)
+    if total_examples == 0:
+        raise ValueError("FedDF proxy dataset is empty.")
+
+    subset_size = min(size, total_examples)
+    generator = torch.Generator().manual_seed(seed)
+    indices = torch.randperm(total_examples, generator=generator)[:subset_size].tolist()
+    indices.sort()
+
+    return Subset(dataset, indices), indices
+
+
+def unwrap_model_outputs(outputs: Any) -> torch.Tensor:
+    """Normalise model outputs to a logits tensor."""
+    if isinstance(outputs, (tuple, list)):
+        return outputs[0]
+    if not isinstance(outputs, torch.Tensor):
+        raise TypeError(
+            "FedDF expects the model forward pass to return a tensor or "
+            f"tensor-like tuple, received {type(outputs).__name__}."
+        )
+    return outputs
+
+
+def extract_batch_inputs(batch: Any) -> Any:
+    """Return the model input tensor from a dataset batch."""
+    if isinstance(batch, (tuple, list)):
+        return batch[0]
+    if isinstance(batch, dict):
+        for key in ("input", "inputs", "image", "images", "x"):
+            if key in batch:
+                return batch[key]
+    return batch
+
+
+def collect_proxy_logits(
+    model: torch.nn.Module,
+    proxy_dataset: Dataset,
+    *,
+    batch_size: int,
+    device: torch.device | str,
+) -> torch.Tensor:
+    """Evaluate a model on the proxy set and return detached logits."""
+    was_training = model.training
+    device = torch.device(device)
+    dataloader = DataLoader(proxy_dataset, batch_size=batch_size, shuffle=False)
+    logits: list[torch.Tensor] = []
+
+    model.to(device)
+    model.eval()
+
+    with torch.no_grad():
+        for batch in dataloader:
+            inputs = extract_batch_inputs(batch).to(device)
+            batch_logits = unwrap_model_outputs(model(inputs))
+            logits.append(batch_logits.detach().cpu())
+
+    if was_training:
+        model.train()
+
+    return torch.cat(logits, dim=0)
+
+
+def stack_proxy_inputs(proxy_dataset: Dataset) -> torch.Tensor:
+    """Materialise the proxy inputs in their deterministic dataset order."""
+    inputs = []
+    for example in proxy_dataset:
+        inputs.append(extract_batch_inputs(example))
+
+    return torch.stack(inputs)
diff --git a/plato/datasources/torchvision.py b/plato/datasources/torchvision.py
index 5a6eb116d..64f2c2dcc 100644
--- a/plato/datasources/torchvision.py
+++ b/plato/datasources/torchvision.py
@@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional
 import torch
+from torch.utils.data import Subset
 from torchvision import datasets, transforms
 
 from plato.config import Config
@@ -52,6 +53,30 @@ def _to_plain_list(value: Any) -> list[Any]:
     raise TypeError("Expected a list-like object for dataset arguments.")
 
 
+def _normalize_subset_spec(value: Any) -> dict[str, int] | None:
+    """Normalize an optional subset-selection config into plain integers."""
+    spec = _to_plain_dict(value)
+    if not spec:
+        return None
+
+    normalized: dict[str, int] = {}
+    if "seed" in spec:
+        normalized["seed"] = int(spec["seed"])
+
+    start = int(spec.get("start", 0))
+    if start < 0:
+        raise ValueError("Subset start must be non-negative.")
+    normalized["start"] = start
+
+    if "size" in spec:
+        size = int(spec["size"])
+        if size < 0:
+            raise ValueError("Subset size must be non-negative.")
+        normalized["size"] = size
+
+    return normalized
+
+
 def _default_transform():
     """Factory providing a fresh default transform for image datasets."""
     return transforms.ToTensor()
@@ -264,6 +289,11 @@ def __init__(self, **kwargs):
         user_unlabeled_kwargs = _to_plain_dict(
             getattr(data_cfg, "unlabeled_kwargs", None)
         )
+        train_subset = _normalize_subset_spec(getattr(data_cfg, "train_subset", None))
+        test_subset = _normalize_subset_spec(getattr(data_cfg, "test_subset", None))
+        unlabeled_subset = _normalize_subset_spec(
+            getattr(data_cfg, "unlabeled_subset", None)
+        )
 
         common_args = list(dataset_defaults.get("dataset_args", [])) + _to_plain_list(
             getattr(data_cfg, "dataset_args", None)
@@ -355,6 +385,9 @@ def __init__(self, **kwargs):
                 train_transform,
                 train_target_transform,
             )
+            self.trainset = self._subset_dataset(
+                self.trainset, train_subset, subset_name="train"
+            )
         else:
             self.trainset = None
@@ -375,6 +408,9 @@ def __init__(self, **kwargs):
                 test_transform,
                 test_target_transform,
             )
+            self.testset = self._subset_dataset(
+                self.testset, test_subset, subset_name="test"
+            )
         else:
             self.testset = None
@@ -395,6 +431,9 @@ def __init__(self, **kwargs):
                 unlabeled_transform,
                 unlabeled_target_transform,
             )
+            self.unlabeledset = self._subset_dataset(
+                self.unlabeledset, unlabeled_subset, subset_name="unlabeled"
+            )
         else:
             self.unlabeledset = None
@@ -525,25 +564,82 @@ def _attach_metadata(dataset):
         if not hasattr(dataset, "classes") and hasattr(dataset, "class_to_idx"):
             dataset.classes = list(dataset.class_to_idx.keys())
 
-    def classes(self):
-        dataset = self.trainset or self.testset
-        if dataset is None:
-            return []
+    @staticmethod
+    def _dataset_classes(dataset):
         if hasattr(dataset, "classes") and dataset.classes is not None:
             return list(dataset.classes)
         if hasattr(dataset, "class_to_idx"):
             return list(dataset.class_to_idx.keys())
-        return []
+        return None
+
+    @staticmethod
+    def _dataset_targets(dataset):
+        targets = None
+        if hasattr(dataset, "targets"):
+            targets = dataset.targets
+        elif hasattr(dataset, "labels"):
+            targets = dataset.labels
+
+        if targets is None:
+            return None
+        if isinstance(targets, torch.Tensor):
+            return targets.tolist()
+        if isinstance(targets, tuple):
+            return list(targets)
+        if hasattr(targets, "tolist") and not isinstance(targets, list):
+            return targets.tolist()
+        return list(targets)
+
+    def _subset_dataset(self, dataset, subset_spec, *, subset_name: str):
+        """Apply a deterministic subset slice while preserving metadata."""
+        if dataset is None or subset_spec is None:
+            return dataset
+
+        total_examples = len(dataset)
+        start = subset_spec.get("start", 0)
+        size = subset_spec.get("size", total_examples - start)
+        stop = start + size
+
+        if start > total_examples:
+            raise ValueError(
+                f"{subset_name}_subset start {start} exceeds dataset size {total_examples}."
+            )
+        if stop > total_examples:
+            raise ValueError(
+                f"{subset_name}_subset stop {stop} exceeds dataset size {total_examples}."
+            )
+
+        indices = list(range(total_examples))
+        if "seed" in subset_spec:
+            generator = torch.Generator().manual_seed(subset_spec["seed"])
+            indices = torch.randperm(total_examples, generator=generator).tolist()
+
+        selected_indices = indices[start:stop]
+        subset = Subset(dataset, selected_indices)
+
+        classes = self._dataset_classes(dataset)
+        if classes is not None:
+            subset.classes = classes
+
+        targets = self._dataset_targets(dataset)
+        if targets is not None:
+            subset.targets = [targets[index] for index in selected_indices]
+
+        return subset
+
+    def classes(self):
+        dataset = self.trainset or self.testset
+        if dataset is None:
+            return []
+        classes = self._dataset_classes(dataset)
+        return [] if classes is None else classes
 
     def targets(self):
         dataset = self.trainset or self.testset
         if dataset is None:
             return []
-        if hasattr(dataset, "targets"):
-            return dataset.targets
-        if hasattr(dataset, "labels"):
-            return dataset.labels
-        return []
+        targets = self._dataset_targets(dataset)
+        return [] if targets is None else targets
 
     def get_unlabeled_set(self):
         return getattr(self, "unlabeledset", None)
diff --git a/tests/clients/test_feddf_strategy.py b/tests/clients/test_feddf_strategy.py
new file mode 100644
index 000000000..6676eda37
--- /dev/null
+++ b/tests/clients/test_feddf_strategy.py
@@ -0,0 +1,83 @@
+"""Tests for the FedDF example training strategy."""
+
+from __future__ import annotations
+
+import asyncio
+import importlib.util
+import sys
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, cast
+from unittest.mock import AsyncMock, patch
+
+import torch
+
+from tests.test_utils.fakes import FakeModel
+
+_TESTS_ROOT = Path(__file__).resolve().parent
+_FEDDF_DIR = (
+    _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf"
+)
+if str(_FEDDF_DIR) not in sys.path:
+    sys.path.insert(0, str(_FEDDF_DIR))
+
+_FEDDF_CLIENT_PATH = _FEDDF_DIR / "feddf_client.py"
+_FEDDF_SPEC = importlib.util.spec_from_file_location(
+    "feddf_client_module", _FEDDF_CLIENT_PATH
+)
+if _FEDDF_SPEC is None:
+    raise RuntimeError(f"Unable to load spec for {_FEDDF_CLIENT_PATH}")
+
+feddf_client = cast(Any, importlib.util.module_from_spec(_FEDDF_SPEC))
+loader = _FEDDF_SPEC.loader
+if loader is None:
+    raise RuntimeError(f"Loader missing for {_FEDDF_CLIENT_PATH}")
+loader.exec_module(feddf_client)
+
+
+def test_feddf_training_strategy_returns_teacher_logits(temp_config):
+    """FedDF clients should send proxy-set logits instead of model weights."""
+    strategy = feddf_client.FedDFTrainingStrategy(
+        proxy_batch_size=2,
+    )
+    loaded_weights = []
+    context = SimpleNamespace(
+        client_id=1,
+        current_round=1,
+        algorithm=SimpleNamespace(load_weights=lambda weights: loaded_weights.append(weights)),
+        trainer=SimpleNamespace(model=FakeModel(), device="cpu"),
+        state={},
+    )
+    proxy_inputs = torch.randn(3, 4)
+    inbound_payload = {
+        "weights": {"linear.weight": torch.ones(2, 4)},
+        "proxy_inputs": proxy_inputs,
+    }
+    strategy.load_payload(context, inbound_payload)
+
+    mock_report = SimpleNamespace(num_samples=8)
+    async_mock = AsyncMock(return_value=(mock_report, {"weights": torch.ones(1)}))
+
+    with patch.object(
+        feddf_client.DefaultTrainingStrategy,
+        "train",
+        new=async_mock,
+    ) as mock_train, patch.object(
+        feddf_client.time,
+        "perf_counter",
+        side_effect=[10.0, 10.25],
+    ):
+        report, payload = asyncio.run(strategy.train(context))
+
+    mock_train.assert_awaited_once()
+    assert report is mock_report
+    assert getattr(report, "training_time") == 0.25
+    assert getattr(report, "feddf_proxy_logits_time") == 0.25
+    assert getattr(report, "payload_type") == "feddf_logits"
+    assert getattr(report, "proxy_size") == 3
+    assert "logits" in payload
+    assert "weights" not in payload
+    assert tuple(payload["logits"].shape) == (3, 2)
+    assert loaded_weights == [inbound_payload["weights"]]
+    assert torch.equal(context.state["feddf_proxy_inputs"], proxy_inputs)
+    assert context.state["feddf_proxy_logits_time"] == 0.25
diff --git a/tests/datasources/test_torchvision_datasource.py b/tests/datasources/test_torchvision_datasource.py
index fbd52ed23..ada748dd3 100644
--- a/tests/datasources/test_torchvision_datasource.py
+++ b/tests/datasources/test_torchvision_datasource.py
@@ -189,6 +189,65 @@ def __getitem__(self, index):
     assert datasource.classes() == ["neg", "pos"]
 
 
+def test_torchvision_datasource_supports_deterministic_non_overlapping_subsets(
+    monkeypatch, tmp_path
+):
+    """Subset configs should carve deterministic, disjoint slices from one split."""
+
+    class DummyBoolDataset:
+        def __init__(
+            self,
+            root,
+            train=True,
+            download=False,
+            transform=None,
+        ):
+            self.root = root
+            self.train = train
+            self.download = download
+            self.transform = transform
+            self.targets = list(range(10))
+            self.classes = ("neg", "pos")
+            self.data = list(range(10))
+
+        def __len__(self):
+            return len(self.data)
+
+        def __getitem__(self, index):
+            return self.data[index], self.targets[index]
+
+    stub_datasets = types.SimpleNamespace(DummyBoolDataset=DummyBoolDataset)
+    dummy_config = _build_config(
+        tmp_path,
+        {
+            "datasource": "Torchvision",
+            "dataset_name": "DummyBoolDataset",
+            "download": False,
+            "unlabeled_split": "train",
+            "train_subset": {"seed": 7, "start": 2, "size": 4},
+            "unlabeled_subset": {"seed": 7, "start": 0, "size": 2},
+        },
+    )
+
+    monkeypatch.setattr(torchvision_ds, "datasets", stub_datasets)
+    monkeypatch.setattr(torchvision_ds, "transforms", _StubTransforms())
+    monkeypatch.setattr(torchvision_ds, "Config", lambda: dummy_config)
+
+    datasource = torchvision_ds.DataSource()
+
+    expected_indices = torch.randperm(
+        10, generator=torch.Generator().manual_seed(7)
+    ).tolist()
+    assert datasource.trainset.indices == expected_indices[2:6]
+    assert datasource.get_unlabeled_set().indices == expected_indices[:2]
+    assert set(datasource.trainset.indices).isdisjoint(
+        datasource.get_unlabeled_set().indices
+    )
+    assert datasource.targets() == [3, 4, 1, 7]
+    assert datasource.get_unlabeled_set().targets == [5, 0]
+    assert datasource.classes() == ["neg", "pos"]
+
+
 def test_torchvision_datasource_celeba_defaults(monkeypatch, tmp_path):
     """CelebA should inherit legacy defaults including target handling."""
diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py
new file mode 100644
index 000000000..edf382b1d
--- /dev/null
+++ b/tests/servers/test_feddf_server_strategy.py
@@ -0,0 +1,226 @@
+"""Tests for the FedDF example aggregation strategy."""
+
+from __future__ import annotations
+
+import asyncio
+import importlib.util
+import sys
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, cast
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import TensorDataset
+
+from plato.config import Config
+
+_TESTS_ROOT = Path(__file__).resolve().parent
+_FEDDF_DIR = (
+    _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf"
+)
+if str(_FEDDF_DIR) not in sys.path:
+    sys.path.insert(0, str(_FEDDF_DIR))
+
+
+def _load_module(module_name: str, path: Path) -> Any:
+    spec = importlib.util.spec_from_file_location(module_name, path)
+    if spec is None:
+        raise RuntimeError(f"Unable to load spec for {path}")
+
+    module = cast(Any, importlib.util.module_from_spec(spec))
+    loader = spec.loader
+    if loader is None:
+        raise RuntimeError(f"Loader missing for {path}")
+
+    loader.exec_module(module)
+    return module
+
+
+feddf_algorithm = _load_module(
+    "feddf_algorithm_module",
+    _FEDDF_DIR / "feddf_algorithm.py",
+)
+feddf_server_strategy = _load_module(
+    "feddf_server_strategy_module",
+    _FEDDF_DIR / "feddf_server_strategy.py",
+)
+
+
+class TinyStudent(torch.nn.Module):
+    """Small student model used to verify server-side distillation."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = torch.nn.Linear(2, 2, bias=False)
+        with torch.no_grad():
+            self.linear.weight.zero_()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return self.linear(inputs)
+
+
+class SharedProxyDatasource:
+    """Datasource stub exposing both unlabeled and test splits for FedDF."""
+
+    def __init__(
+        self,
+        proxy_inputs: torch.Tensor,
+        test_inputs: torch.Tensor | None = None,
+    ) -> None:
+        self._unlabeled = TensorDataset(proxy_inputs, torch.zeros(len(proxy_inputs)))
+        self._test = TensorDataset(
+            test_inputs if test_inputs is not None else proxy_inputs,
+            torch.zeros(len(test_inputs) if test_inputs is not None else len(proxy_inputs)),
+        )
+
+    def get_unlabeled_set(self):
+        return self._unlabeled
+
+    def get_test_set(self):
+        return self._test
+
+
+def test_feddf_server_process_reports_distills_global_model(temp_config):
+    """FedDF should consume logits payloads and update the global model."""
+    feddf_server = _load_module(
+        "feddf_server_module",
+        _FEDDF_DIR / "feddf_server.py",
+    )
+
+    Config().server.do_test = False
+
+    trainer = SimpleNamespace(model=TinyStudent(), device="cpu")
+    algorithm = feddf_algorithm.Algorithm(trainer=trainer)
+    strategy = feddf_server_strategy.FedDFAggregationStrategy(
+        proxy_set_size=4,
+        proxy_seed=1,
+        temperature=1.0,
+        distillation_epochs=80,
+        distillation_batch_size=2,
+        distillation_learning_rate=0.4,
+    )
+    proxy_inputs = torch.tensor(
+        [
+            [2.0, 0.0],
+            [0.0, 2.0],
+            [1.5, 0.2],
+            [0.2, 1.5],
+        ]
+    )
+    held_out_test_inputs = torch.tensor(
+        [
+            [9.0, 9.0],
+            [8.0, 8.0],
+            [7.0, 7.0],
+            [6.0, 6.0],
+        ]
+    )
+    server = feddf_server.Server(
+        aggregation_strategy=strategy,
+        datasource=lambda: SharedProxyDatasource(proxy_inputs, held_out_test_inputs),
+    )
+    server.algorithm = algorithm
+    server.trainer = trainer
+    server.context.server = server
+    server.context.algorithm = algorithm
+    server.context.trainer = trainer
+    server.context.state["prng_state"] = None
+    server_payload = server.customize_server_payload(algorithm.extract_weights())
+    assert set(server_payload.keys()) == {"weights", "proxy_inputs"}
+    assert torch.equal(server_payload["proxy_inputs"], proxy_inputs)
+
+    teacher_logits_a = torch.tensor(
+        [
+            [7.0, -7.0],
+            [-7.0, 7.0],
+            [5.5, -5.5],
+            [-5.5, 5.5],
+        ]
+    )
+    teacher_logits_b = torch.tensor(
+        [
+            [6.0, -6.0],
+            [-6.0, 6.0],
+            [4.5, -4.5],
+            [-4.5, 4.5],
+        ]
+    )
+
+    server.updates = [
+        SimpleNamespace(
+            client_id=1,
+            report=SimpleNamespace(
+                num_samples=1,
+                accuracy=0.0,
+                processing_time=0.0,
+                comm_time=0.0,
+                training_time=0.0,
+            ),
+            payload={"logits": teacher_logits_a},
+        ),
+        SimpleNamespace(
+            client_id=2,
+            report=SimpleNamespace(
+                num_samples=1,
+                accuracy=0.0,
+                processing_time=0.0,
+                comm_time=0.0,
+                training_time=0.0,
+            ),
+            payload={"logits": teacher_logits_b},
+        ),
+    ]
+
+    baseline_log_probs = torch.log_softmax(torch.zeros_like(teacher_logits_a), dim=1)
+    teacher_targets = torch.softmax((teacher_logits_a + teacher_logits_b) / 2, dim=1)
+    baseline_loss = F.kl_div(
+        baseline_log_probs,
+        teacher_targets,
+        reduction="batchmean",
+    )
+
+    asyncio.run(server._process_reports())
+
+    updated_weights = algorithm.extract_weights()
+    assert not torch.allclose(
+        updated_weights["linear.weight"],
+        torch.zeros_like(updated_weights["linear.weight"]),
+    )
+
+    with torch.no_grad():
+        updated_logits = trainer.model(proxy_inputs)
+        updated_log_probs = torch.log_softmax(updated_logits, dim=1)
+
+    distilled_loss = F.kl_div(
+        updated_log_probs,
+        teacher_targets,
+        reduction="batchmean",
+    )
+    assert distilled_loss < baseline_loss
+    assert server.feddf_server_distillation_time > 0
+
+    logged_items = server.get_logged_items()
+    assert logged_items["feddf_server_distillation_time"] == (
+        server.feddf_server_distillation_time
+    )
+    assert logged_items["round_time"] == server.feddf_server_distillation_time
+    assert logged_items["elapsed_time"] >= server.feddf_server_distillation_time
+
+
+def test_feddf_teacher_logits_average_uniformly_by_default(temp_config):
+    """FedDF should use uniform AVGLOGITS unless configured otherwise."""
+    updates = [
+        SimpleNamespace(report=SimpleNamespace(num_samples=1)),
+        SimpleNamespace(report=SimpleNamespace(num_samples=99)),
+    ]
+    teacher_logits_a = torch.tensor([[10.0, -10.0]])
+    teacher_logits_b = torch.tensor([[-6.0, 6.0]])
+
+    aggregated = feddf_algorithm.Algorithm.aggregate_teacher_logits(
+        updates,
+        [{"logits": teacher_logits_a}, {"logits": teacher_logits_b}],
+    )
+
+    expected = (teacher_logits_a + teacher_logits_b) / 2
+    assert torch.allclose(aggregated, expected)
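The distillation step the tests above exercise (uniform AVGLOGITS averaging followed by a temperature-scaled KL update against the ensemble's soft targets) can be sketched standalone. This is a minimal illustration of the objective, not the example's actual API: `distill_step` is a hypothetical helper that mirrors what `feddf_algorithm.distill_weights` is described as doing, using only core PyTorch calls.

```python
import torch
import torch.nn.functional as F


def distill_step(student, optimizer, proxy_inputs, client_logits, temperature=1.0):
    """One FedDF-style update: average the client (teacher) logits uniformly,
    smooth them with the temperature, and move the student toward the
    resulting soft targets via a KL-divergence loss."""
    # AVGLOGITS: uniform mean over the per-client logits on the shared proxy set.
    ensemble_logits = torch.stack(client_logits).mean(dim=0)
    teacher_targets = F.softmax(ensemble_logits / temperature, dim=1)

    optimizer.zero_grad()
    student_log_probs = F.log_softmax(student(proxy_inputs) / temperature, dim=1)
    loss = F.kl_div(student_log_probs, teacher_targets, reduction="batchmean")
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy usage mirroring the server-strategy test: a zero-initialised linear
# student distilled toward two synthetic client teachers.
student = torch.nn.Linear(2, 2, bias=False)
torch.nn.init.zeros_(student.weight)
optimizer = torch.optim.SGD(student.parameters(), lr=0.4)

proxy_inputs = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
client_logits = [
    torch.tensor([[7.0, -7.0], [-7.0, 7.0]]),
    torch.tensor([[6.0, -6.0], [-6.0, 6.0]]),
]

losses = [
    distill_step(student, optimizer, proxy_inputs, client_logits)
    for _ in range(20)
]
assert losses[-1] < losses[0]  # the student converges toward the ensemble
```

The uniform mean (rather than a sample-weighted one) matches the default behaviour checked by `test_feddf_teacher_logits_average_uniformly_by_default`.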