40 changes: 21 additions & 19 deletions docs/docs/examples/algorithms/1. Server Aggregation Algorithms.md
@@ -56,40 +56,42 @@ Key configuration parameters:

---

### MOON
### FedDF

MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level contrastive regularizer. Each client augments the shared model with a projection head, clones the incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints as negatives. The server still performs sample-weighted averaging but records a short history of global states for downstream analysis or warm restarts.
FedDF (Federated Distillation and Fusion) replaces direct parameter averaging with server-side distillation on a proxy set. The server selects a deterministic unlabeled proxy subset, ships those proxy inputs alongside the current global weights, and each client returns teacher logits on that shared payload instead of ordinary weight deltas. The server then aggregates those logits and distills their ensemble into the next global model using temperature-scaled soft targets.

```bash
cd examples/server_aggregation/moon/
uv run moon.py -c moon_MNIST_lenet5.toml
cd examples/server_aggregation/feddf/
uv run feddf.py -c feddf_MNIST_lenet5.toml
```

Key configuration parameters:

- `algorithm.mu`: Weight assigned to the contrastive term (default: 5.0).
- `algorithm.temperature`: Softmax temperature applied to cosine similarities (default: 0.5).
- `algorithm.history_size`: Number of historical local models cached per client as negatives (default: 2).
- `trainer.model_name`: Name used for checkpointing the projection-ready backbone (default: `moon_lenet5`).
- `algorithm.proxy_set_size`: Number of unlabeled proxy samples used on the server for distillation.
- `algorithm.proxy_batch_size`: Batch size for iterating through the proxy set.
- `algorithm.proxy_seed`: Seed used to select the deterministic proxy subset shared by clients and server.
- `algorithm.temperature`: Softmax temperature used to smooth teacher logits before distillation.
- `algorithm.distillation_epochs`: Number of server-side distillation passes per round.
- `algorithm.distillation_batch_size`: Batch size for the distillation optimizer.
- `algorithm.learning_rate`: Learning rate for the server-side distillation optimizer.

**Reference:** Qinbin Li, Bingsheng He, Dawn Song. “[Model-Contrastive Federated Learning](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Model-Contrastive_Federated_Learning_CVPR_2021_paper.pdf),” in Proc. CVPR, 2021.
**Reference:** Tao Lin, Lingjing Kong, Sebastian U. Stich, Martin Jaggi. "[Ensemble Distillation for Robust Model Fusion in Federated Learning](https://arxiv.org/abs/2006.07242)," in Proc. NeurIPS, 2020.
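
The round flow described above reduces to two operations on the shared proxy set: averaging the per-client logits (AVGLOGITS) and minimizing a temperature-scaled KL divergence between the softened ensemble and the student. The following is a minimal, self-contained sketch of those steps under the configuration names above; it is an illustration rather than the Plato implementation, and `select_proxy_subset`, `ensemble_teacher_logits`, and `distillation_loss` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F


def select_proxy_subset(dataset_size: int, proxy_set_size: int, proxy_seed: int) -> torch.Tensor:
    """Deterministically pick the proxy indices shared by the server and all clients."""
    generator = torch.Generator().manual_seed(proxy_seed)
    return torch.randperm(dataset_size, generator=generator)[:proxy_set_size]


def ensemble_teacher_logits(client_logits: list) -> torch.Tensor:
    """Uniformly average per-client logits computed on the same proxy inputs (AVGLOGITS)."""
    return torch.stack(client_logits).mean(dim=0)


def distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                      temperature: float = 2.0) -> torch.Tensor:
    """Temperature-scaled KL divergence between softened teacher and student distributions."""
    teacher_probs = torch.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2
```

The `temperature ** 2` factor is the usual distillation convention that keeps gradient magnitudes comparable across temperatures; the server repeats this loss minimization for `distillation_epochs` passes over the proxy set each round.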

!!! note "Alignment with the paper"
Here’s how Plato's implementation lines up with Li et al. (CVPR 2021) and the authors’ [reference implementation](https://github.com/Xtra-Computing/MOON):
The module split follows the FedDF workflow directly: `feddf.py` stays as a thin launcher, `feddf_server.py` packages the current global weights together with the shared proxy inputs, `feddf_client.py` performs the standard local update and then emits logits on that server-supplied proxy payload, `feddf_server_strategy.py` resolves the deterministic proxy subset and routes the logits payload through direct weight aggregation, and `feddf_algorithm.py` encapsulates the weighted-logit ensemble plus the temperature-scaled KL distillation step.

- Projection head & representations – `moon_model.py:31-79` implements the LeNet-style backbone plus a two-layer projection head, returning both logits and L2-normalised embeddings. The paper’s Eq. (3) (and typical contrastive-learning practice) calls for that projection step; the public repo’s simple CNN head even hints at it (they keep the projection MLP commented out). So keeping the projection in our model is faithful and helps the cosine similarities stay well behaved.
The configuration surface above mirrors the paper’s core knobs. `proxy_set_size`, `proxy_batch_size`, and `proxy_seed` control the deterministic unlabeled proxy data used for ensemble distillation, `temperature` shapes the softened teacher distribution, and `distillation_epochs`, `distillation_batch_size`, and `learning_rate` control the server-side student optimization that replaces direct averaging.

- Local training objective – `moon_trainer.py:26-152` combines the supervised cross-entropy with the temperature-scaled contrastive loss exactly like Eq. (1): positives come from the frozen global model, negatives from the stored local-history models, using the same \\(\\mu\\) and \\(\\tau\\) hyper-parameters exposed in the config (`moon_MNIST_lenet5.toml:41-45`). This mirrors `train_net_fedcon` in the reference implementation, which also weights the contrastive term by \\(\\mu\\) and uses CrossEntropy on logits built from cosine similarities.

- Historical model buffer – the client keeps a FIFO queue of past local checkpoints (`moon_client.py:21-64`), equivalent to `model_buffer_size` in the paper and the author's reference implementation; that buffer is fed into the trainer through the strategy context so MOON always has negatives available.
---

- Server aggregation – the server still performs sample-weighted FedAvg (`moon_server.py:12-35`, `moon_server_strategy.py:19-63`), matching the MOON design which leaves the aggregation rule unchanged. The extra global-history deque is bookkeeping-only.
### MOON (client-training customization)

- Shared architecture – `moon.py:8-15` now instantiates `MoonModel` once and passes it into both the client and server `(model=model)`. That guarantees the projection-enabled architecture is shared exactly, as required for the contrastive comparisons.
MOON is included in Plato as a **client-training customization** rather than as a server-aggregation
rule. The server still performs sample-weighted FedAvg; the distinguishing mechanism is the
contrastive local objective together with the historical-model buffer maintained by each client.

The only intentional deviation is that we L2-normalise the projection outputs before computing cosine similarities
(`moon_model.py:76-79`), which the paper assumes implicitly and improves stability. Aside from that, the workflow,
hyper-parameters, and loss all line up with the CVPR paper and the publicly released PyTorch reference.
See **5. Algorithms with Customized Client Training Loops** for the runnable example, key
configuration parameters, and the implementation alignment notes for MOON.

---

@@ -58,6 +58,45 @@ uv run feddyn/feddyn.py -c feddyn/feddyn_MNIST_lenet5.toml

---

### MOON

MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level
contrastive regularizer. Each client augments the shared model with a projection head, clones the
incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints
as negatives. The server still performs sample-weighted averaging but records a short history of
global states for downstream analysis or warm restarts.

```bash
cd examples/server_aggregation/moon/
uv run moon.py -c moon_MNIST_lenet5.toml
```

Key configuration parameters:

- `algorithm.mu`: Weight assigned to the contrastive term (default: 5.0).
- `algorithm.temperature`: Softmax temperature applied to cosine similarities (default: 0.5).
- `algorithm.history_size`: Number of historical local models cached per client as negatives (default: 2).
- `trainer.model_name`: Name used for checkpointing the projection-ready backbone (default: `moon_lenet5`).

**Reference:** Qinbin Li, Bingsheng He, Dawn Song. “[Model-Contrastive Federated Learning](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Model-Contrastive_Federated_Learning_CVPR_2021_paper.pdf),” in Proc. CVPR, 2021.
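
To make the interplay of `mu` and `temperature` concrete, here is a minimal sketch of the MOON local objective; it is illustrative only and not the Plato trainer, and `local_z`, `global_z`, and `prev_zs` stand for projection-head outputs of the current local model, the frozen global model, and the cached historical models respectively.

```python
import torch
import torch.nn.functional as F


def moon_loss(logits, labels, local_z, global_z, prev_zs, mu=5.0, temperature=0.5):
    """Supervised cross-entropy plus MOON's model-contrastive term."""
    ce_loss = F.cross_entropy(logits, labels)

    # Cosine similarity to the frozen global model forms the positive pair;
    # similarities to the cached historical local models form the negatives.
    positive = F.cosine_similarity(local_z, global_z, dim=1).unsqueeze(1)
    negatives = [F.cosine_similarity(local_z, z, dim=1).unsqueeze(1) for z in prev_zs]

    contrastive_logits = torch.cat([positive] + negatives, dim=1) / temperature
    # The positive sits in column 0, so every sample's contrastive "label" is 0.
    targets = torch.zeros(local_z.size(0), dtype=torch.long, device=local_z.device)
    contrastive_loss = F.cross_entropy(contrastive_logits, targets)

    return ce_loss + mu * contrastive_loss
```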

!!! note "Alignment with the paper"
Here’s how Plato's implementation lines up with Li et al. (CVPR 2021) and the authors’ [reference implementation](https://github.com/Xtra-Computing/MOON):

- Projection head & representations – `moon_model.py:31-79` implements the LeNet-style backbone plus a two-layer projection head, returning both logits and L2-normalised embeddings. The paper’s Eq. (3) (and typical contrastive-learning practice) calls for that projection step; the public repo’s simple CNN head even hints at it (they keep the projection MLP commented out). So keeping the projection in our model is faithful and helps the cosine similarities stay well behaved.

- Local training objective – `moon_trainer.py:26-152` combines the supervised cross-entropy with the temperature-scaled contrastive loss exactly like Eq. (1): positives come from the frozen global model, negatives from the stored local-history models, using the same \\(\\mu\\) and \\(\\tau\\) hyper-parameters exposed in the config (`moon_MNIST_lenet5.toml:41-45`). This mirrors `train_net_fedcon` in the reference implementation, which also weights the contrastive term by \\(\\mu\\) and uses CrossEntropy on logits built from cosine similarities.

- Historical model buffer – the client keeps a FIFO queue of past local checkpoints (`moon_client.py:21-64`), equivalent to `model_buffer_size` in the paper and the authors' reference implementation; that buffer is fed into the trainer through the strategy context so MOON always has negatives available (a minimal sketch follows this note).

- Server aggregation – the server still performs sample-weighted FedAvg (`moon_server.py:12-35`, `moon_server_strategy.py:19-63`), matching the MOON design which leaves the aggregation rule unchanged. The extra global-history deque is bookkeeping-only.

- Shared architecture – `moon.py:8-15` now instantiates `MoonModel` once and passes it into both the client and server `(model=model)`. That guarantees the projection-enabled architecture is shared exactly, as required for the contrastive comparisons.

The only intentional deviation is that we L2-normalise the projection outputs before computing cosine similarities (`moon_model.py:76-79`), which the paper assumes implicitly and improves stability. Aside from that, the workflow, hyper-parameters, and loss all line up with the CVPR paper and the publicly released PyTorch reference.
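
A compact way to picture the historical-model buffer from the bullets above is a fixed-length FIFO of frozen snapshots; `MoonHistoryBuffer` is an illustrative name for this sketch, not the class used in `moon_client.py`.

```python
import copy
from collections import deque

import torch.nn as nn


class MoonHistoryBuffer:
    """FIFO cache of past local models that serve as contrastive negatives (sketch)."""

    def __init__(self, history_size: int = 2):
        # deque(maxlen=...) silently evicts the oldest checkpoint once full.
        self.models = deque(maxlen=history_size)

    def push(self, model: nn.Module) -> None:
        snapshot = copy.deepcopy(model).eval()
        for param in snapshot.parameters():
            param.requires_grad_(False)  # negatives stay frozen during local training
        self.models.append(snapshot)
```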

---

### FedMoS

FedMoS is a communication-efficient FL framework that couples a double momentum-based update with adaptive client selection to jointly mitigate the intrinsic variance.
19 changes: 19 additions & 0 deletions examples/server_aggregation/feddf/feddf.py
@@ -0,0 +1,19 @@
"""
Entry point for running the FedDF server aggregation example.
"""

from __future__ import annotations

import feddf_client
import feddf_server


def main():
    """Launch a Plato training session with the FedDF algorithm."""
    client = feddf_client.create_client()
    server = feddf_server.Server()
    server.run(client)


if __name__ == "__main__":
    main()
64 changes: 64 additions & 0 deletions examples/server_aggregation/feddf/feddf_MNIST_lenet5.toml
@@ -0,0 +1,64 @@
[clients]
type = "simple"
total_clients = 100
per_round = 10
random_seed = 1
do_test = false
speed_simulation = true
sleep_simulation = true
avg_training_time = 20

[clients.simulation_distribution]
distribution = "normal"
mean = 10
sd = 3

[server]
address = "127.0.0.1"
port = 8000
synchronous = false
simulate_wall_time = true
minimum_clients_aggregated = 6
staleness_bound = 10
random_seed = 1

[data]
datasource = "Torchvision"
dataset_name = "MNIST"
download = true
data_path = "data"
partition_size = 200
sampler = "noniid"
concentration = 0.5
random_seed = 1

[trainer]
rounds = 20
type = "basic"
max_concurrency = 4
target_accuracy = 1.0
epochs = 5
batch_size = 64
optimizer = "SGD"
model_name = "lenet5"

[algorithm]
type = "fedavg"
proxy_set_size = 2048
proxy_batch_size = 128
proxy_seed = 1
temperature = 2.0
distillation_epochs = 5
distillation_batch_size = 64
learning_rate = 0.001

[parameters]

[parameters.optimizer]
lr = 0.01
momentum = 0.9
weight_decay = 0.0

[results]
result_path = "results/MNIST_lenet5/FedDF"
types = "round, accuracy, elapsed_time, comm_time, round_time"
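
For orientation, the `[algorithm]` keys above would typically be read through Plato's global `Config()` accessor; the snippet below is a hedged sketch of that pattern with fallback defaults chosen to match this file, not a copy of the FedDF strategy's actual code.

```python
from plato.config import Config

# Hypothetical read of the FedDF knobs defined in the [algorithm] section above.
algorithm_config = Config().algorithm
proxy_set_size = getattr(algorithm_config, "proxy_set_size", 2048)
proxy_batch_size = getattr(algorithm_config, "proxy_batch_size", 128)
proxy_seed = getattr(algorithm_config, "proxy_seed", 1)
temperature = getattr(algorithm_config, "temperature", 2.0)
distillation_epochs = getattr(algorithm_config, "distillation_epochs", 5)
distillation_batch_size = getattr(algorithm_config, "distillation_batch_size", 64)
learning_rate = getattr(algorithm_config, "learning_rate", 0.001)
```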
154 changes: 154 additions & 0 deletions examples/server_aggregation/feddf/feddf_algorithm.py
@@ -0,0 +1,154 @@
"""FedDF-specific helpers for ensemble distillation on the server."""

from __future__ import annotations

from collections import OrderedDict
from collections.abc import Mapping, Sequence

import torch
import torch.nn.functional as F
from feddf_utils import extract_batch_inputs, unwrap_model_outputs
from torch.utils.data import DataLoader, Dataset, TensorDataset

from plato.algorithms import fedavg


class Algorithm(fedavg.Algorithm):
    """Algorithm helpers for aggregating logits and distilling the student."""

    @staticmethod
    def aggregate_teacher_logits(
        updates,
        payloads: Sequence[Mapping[str, torch.Tensor]],
        *,
        weighting: str = "uniform",
    ) -> torch.Tensor:
        """Compute the ensembled teacher logits for AVGLOGITS distillation."""
        if not payloads:
            raise ValueError("FedDF requires at least one logits payload.")

        first_logits = payloads[0].get("logits")
        if not isinstance(first_logits, torch.Tensor):
            raise TypeError("FedDF payloads must include a 'logits' tensor.")

        weighting_name = weighting.strip().lower()
        if weighting_name not in {"uniform", "samples"}:
            raise ValueError(
                "FedDF teacher weighting must be either 'uniform' or 'samples'."
            )

        total_samples = sum(getattr(update.report, "num_samples", 0) for update in updates)
        use_uniform_average = weighting_name == "uniform" or total_samples <= 0

        aggregated = torch.zeros_like(first_logits, dtype=torch.float32)

        for update, payload in zip(updates, payloads):
            logits = payload.get("logits")
            if not isinstance(logits, torch.Tensor):
                raise TypeError("FedDF payloads must include a 'logits' tensor.")
            if logits.shape != first_logits.shape:
                raise ValueError(
                    "FedDF client logits must share the same proxy-set shape."
                )

            if use_uniform_average:
                weight = 1 / len(payloads)
            else:
                weight = getattr(update.report, "num_samples", 0) / total_samples

            aggregated += logits.detach().float() * weight

        return aggregated

    def distill_weights(
        self,
        baseline_weights: Mapping[str, torch.Tensor],
        teacher_logits: torch.Tensor,
        proxy_dataset: Dataset,
        *,
        temperature: float,
        distillation_epochs: int,
        distillation_batch_size: int,
        distillation_learning_rate: float,
        distillation_optimizer_name: str,
        use_cosine_annealing: bool,
        shuffle_batches: bool,
    ) -> OrderedDict[str, torch.Tensor]:
        """Distill the server model on proxy inputs using ensemble logits."""
        if len(proxy_dataset) != len(teacher_logits):
            raise ValueError(
                "FedDF proxy samples and teacher logits must have matching lengths."
            )

        trainer = self.require_trainer()
        model = self.require_model()
        device = torch.device(getattr(trainer, "device", "cpu"))

        self.load_weights(baseline_weights)

        inputs = []
        for example in proxy_dataset:
            inputs.append(extract_batch_inputs(example))

        proxy_inputs = torch.stack(inputs)
        distillation_dataset = TensorDataset(proxy_inputs, teacher_logits.detach().cpu())
        dataloader = DataLoader(
            distillation_dataset,
            batch_size=distillation_batch_size,
            shuffle=shuffle_batches,
        )

        was_training = model.training
        model.to(device)
        model.train()

        optimizer_name = distillation_optimizer_name.strip().lower()
        if optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=distillation_learning_rate,
            )
        elif optimizer_name == "sgd":
            optimizer = torch.optim.SGD(
                model.parameters(),
                lr=distillation_learning_rate,
            )
        else:
            raise ValueError(
                "FedDF distillation optimizer must be either 'adam' or 'sgd'."
            )

        total_steps = max(distillation_epochs * len(dataloader), 1)
        scheduler = None
        if use_cosine_annealing:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=total_steps,
            )

        for _ in range(distillation_epochs):
            for batch_inputs, batch_logits in dataloader:
                batch_inputs = batch_inputs.to(device)
                batch_logits = batch_logits.to(device)
                teacher_probs = torch.softmax(batch_logits / temperature, dim=1)

                optimizer.zero_grad()
                student_logits = unwrap_model_outputs(model(batch_inputs))
                student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
                loss = (
                    F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
                    * temperature
                    * temperature
                )
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

        if not was_training:
            model.eval()

        return OrderedDict(
            (name, tensor.detach().cpu().clone())
            for name, tensor in model.state_dict().items()
        )
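
To see how these helpers fit together in a round, a server strategy might call them roughly as follows once the clients' payloads have arrived; `trainer`, `updates`, `payloads`, `baseline_weights`, and `proxy_dataset` are placeholders here, and the real call sites live in `feddf_server_strategy.py`, which is not part of this file.

```python
# Hypothetical call sequence inside a server aggregation strategy (illustration only).
algorithm = Algorithm(trainer=trainer)  # assumes the usual Plato algorithm constructor

teacher_logits = Algorithm.aggregate_teacher_logits(
    updates,              # per-client update records exposing report.num_samples
    payloads,             # per-client dicts, each carrying a "logits" tensor on the proxy set
    weighting="uniform",  # or "samples" for sample-weighted ensembling
)

new_weights = algorithm.distill_weights(
    baseline_weights,     # e.g. the sample-weighted FedAvg aggregate of client weights
    teacher_logits,
    proxy_dataset,        # the deterministic unlabeled proxy subset
    temperature=2.0,
    distillation_epochs=5,
    distillation_batch_size=64,
    distillation_learning_rate=0.001,
    distillation_optimizer_name="adam",
    use_cosine_annealing=True,
    shuffle_batches=True,
)
```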