Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d45f4b2
Document the DiLoCo implementation contract.
baochunli Apr 29, 2026
0378a57
Implemented DiLoCo outer aggregation.
baochunli Apr 29, 2026
6287b0f
Added exact local step limits for trainers.
baochunli Apr 29, 2026
b91d97d
Fixed local step counting for accumulation.
baochunli Apr 29, 2026
6846f50
Added DiLoCo parameter eligibility policy.
baochunli Apr 29, 2026
d122831
Handled adapter payload names in DiLoCo eligibility.
baochunli Apr 29, 2026
a9d8f3b
Avoided DiLoCo adapter alias overmatching.
baochunli Apr 29, 2026
679c9f6
Persisted in-process optimizer state for DiLoCo.
baochunli Apr 29, 2026
114f2a7
Wired DiLoCo server selection.
baochunli Apr 29, 2026
da0f341
Persisted optimizer state across train subprocesses.
baochunli Apr 29, 2026
c730732
Hardened subprocess optimizer state handoff.
baochunli Apr 29, 2026
33129d0
Added DiLoCo payload safety coverage.
baochunli Apr 29, 2026
21bf980
Added round-aware local-step sampling.
baochunli Apr 29, 2026
f6e8196
Handled non-materializable local-step samplers.
baochunli Apr 29, 2026
711fdb1
Added exact DiLoCo smoke configuration.
baochunli Apr 29, 2026
082aaf1
Added end-to-end DiLoCo validation coverage.
baochunli Apr 29, 2026
1313d30
Restored optimizer state after moving models to device.
baochunli Apr 29, 2026
f359d78
Logged DiLoCo outer optimizer application.
baochunli Apr 30, 2026
c365751
Added DiLoCo comparison configs and step-based scheduling.
baochunli Apr 30, 2026
bcc8073
Added MNIST DiLoCo comparison configs.
baochunli Apr 30, 2026
6ffb475
Aligned DiLoCo comparison budgets.
baochunli Apr 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions configs/CIFAR10/diloco_resnet18.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
type = "diloco"
address = "127.0.0.1"
port = 8021

[server.diloco]
outer_optimizer = "nesterov"
outer_learning_rate = 0.7
outer_momentum = 0.9
aggregation_weighting = "uniform"
apply_outer_optimizer_to = "parameters"

[data]

# The training and testing dataset
datasource = "Torchvision"
dataset_name = "CIFAR10"
download = true

# Number of samples in each partition
partition_size = 1000

# IID or non-IID?
sampler = "iid"

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 20

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.9

# Number of local optimizer steps per DiLoCo synchronization.
local_steps_per_round = 500
preserve_optimizer_state = true

# DiLoCo paper inner-optimizer settings.
epochs = 5
batch_size = 10
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

# The machine learning model
model_name = "resnet_18"

[algorithm]

# Weight extraction and model update path reused by DiLoCo.
type = "fedavg"

[parameters]

[parameters.optimizer]
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
warmup_steps = "1000it"
68 changes: 68 additions & 0 deletions configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
address = "127.0.0.1"
port = 8022

[data]

# The training and testing dataset
datasource = "Torchvision"
dataset_name = "CIFAR10"
download = true

# Number of samples in each partition
partition_size = 1000

# IID or non-IID?
sampler = "iid"

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 20

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.9

# Match the original FedAvg local training shape while keeping 500 optimizer
# steps per round, equal to DiLoCo's H.
epochs = 5
batch_size = 10
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

# The machine learning model
model_name = "resnet_18"

[algorithm]

# Aggregation algorithm
type = "fedavg"

[parameters]

[parameters.optimizer]
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
warmup_steps = "1000it"
75 changes: 75 additions & 0 deletions configs/MNIST/diloco_lenet5.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
type = "diloco"
address = "127.0.0.1"
port = 8001
random_seed = 1
simulate_wall_time = true

[server.diloco]
outer_optimizer = "nesterov"
outer_learning_rate = 0.7
outer_momentum = 0.9
aggregation_weighting = "uniform"
apply_outer_optimizer_to = "parameters"

[data]
include = "mnist_iid.toml"
partition_size = 1000

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 20

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.99

# The machine learning model
model_name = "lenet5"

# Number of local optimizer steps per DiLoCo synchronization.
local_steps_per_round = 500
preserve_optimizer_state = true

# DiLoCo paper inner-optimizer settings.
# NOTE(review): 5 epochs over a 1000-sample partition at batch size 32 yields
# only ~160 optimizer steps, fewer than local_steps_per_round = 500. Confirm
# that the exact local-step limit cycles the data stream until H steps are
# reached, rather than the epoch budget capping each round at ~160 steps
# (the CIFAR-10 config aligns the two budgets exactly; this one does not).
epochs = 5
batch_size = 32
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

[algorithm]

# Weight extraction and model update path reused by DiLoCo.
type = "fedavg"

[parameters]

[parameters.model]
num_classes = 10

[parameters.optimizer]
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
warmup_steps = "1000it"
66 changes: 66 additions & 0 deletions configs/MNIST/fedavg_lenet5_diloco_comparison.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
[clients]

# Type
type = "simple"

# The total number of clients
total_clients = 50

# The number of clients selected in each round
per_round = 50

# Should the clients compute test accuracy locally?
do_test = false

[server]
address = "127.0.0.1"
port = 8002
random_seed = 1
simulate_wall_time = true

[data]
include = "mnist_iid.toml"
partition_size = 1000

[trainer]

# The type of the trainer
type = "basic"

# The maximum number of training rounds
rounds = 63

# The maximum number of clients running concurrently
max_concurrency = 7

# The target accuracy
target_accuracy = 0.99

# The machine learning model
model_name = "lenet5"

# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run.
# 5 epochs over 1000 samples at batch size 32 gives 160 optimizer steps per
# round. With 63 rounds, FedAvg gets 10,080 local steps, closely matching
# DiLoCo's 20 * H=500 = 10,000-step total budget.
epochs = 5
batch_size = 32
optimizer = "AdamW"
lr_scheduler = "LambdaLR"

[algorithm]

# Aggregation algorithm
type = "fedavg"

[parameters]

[parameters.model]
num_classes = 10

[parameters.optimizer]
lr = 0.0004
weight_decay = 0.1

[parameters.learning_rate]
warmup_steps = "1000it"
32 changes: 32 additions & 0 deletions docs/docs/configurations/server.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- `fedavg_personalized` a Federated Averaging server that supports all-purpose personalized federated learning by controlling when and which group of clients are to perform local personalization.
- `fedavg_mpc_additive` a Federated Averaging server that reconstructs additive MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_additive` processor.
- `fedavg_mpc_shamir` a Federated Averaging server that reconstructs Shamir MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_shamir` processor.
- `diloco` a FedAvg-compatible server that applies DiLoCo outer aggregation. Use it with `algorithm.type = "fedavg"` and configure the outer optimizer under `[server.diloco]`.
- `split_learning` a Split Learning server that supports training different kinds of models in split learning framework. When this server is used, the `clients.per_round` in the configuration should be set to 1. Users should define the rules for updating models weights before cut from the clients to the server in the callback function `on_update_weights_before_cut`, depending on the specific model they use.
- `fedavg_personalized` a personalized federated learning server that starts from a number of regular rounds of federated learning. In these regular rounds, only a subset of the total clients can be selected to perform the local update (the ratio of which is a configuration setting). After all regular rounds are completed, it starts a final round of personalization, where a selected subset of clients perform local training using their local dataset.
- `pfedgraph` a personalized federated learning server that aggregates models using an inferred collaboration graph and sends per-client aggregated weights.
Expand Down Expand Up @@ -124,6 +125,37 @@

Default value: `100`

!!! example "diloco"
Settings for `server.type = "diloco"`. DiLoCo reuses `algorithm.type = "fedavg"` for client weight extraction and global model loading, while the DiLoCo server turns client deltas into an outer-gradient update.

```toml
[server]
type = "diloco"

[algorithm]
type = "fedavg"

[server.diloco]
outer_optimizer = "nesterov"
outer_learning_rate = 0.7
outer_momentum = 0.9
aggregation_weighting = "uniform"
apply_outer_optimizer_to = "parameters"
```

`aggregation_weighting = "uniform"` matches balanced IID worker smoke runs. `aggregation_weighting = "num_samples"` matches Plato's traditional sample-weighted FedAvg behavior. With outer SGD and `outer_learning_rate = 1.0`, uniform weighting is equivalent to uniform model averaging; with `num_samples`, it is equivalent to Plato-style sample-weighted FedAvg.

`apply_outer_optimizer_to = "parameters"` applies the outer optimizer only to trainable floating parameters. Floating buffers are synchronized with the selected averaging rule but do not receive outer momentum. `apply_outer_optimizer_to = "all_floating"` is available for experiments that also apply the outer optimizer to floating buffers.

Runnable comparison configurations are available for MNIST/LeNet and CIFAR-10/ResNet-18:

```bash
uv run python plato.py --config configs/MNIST/diloco_lenet5.toml
uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml
```

These configurations validate DiLoCo mechanics in Plato; they are not C4/model/pretraining reproductions of the DiLoCo paper.

!!! example "edge_downlink_bandwidth"
The edge server's estimated downlink capacity (an edge server to its clients) in Mbps, used for computing the transmission time (see `compute_comm_time` in the `clients` section).

Expand Down
14 changes: 14 additions & 0 deletions docs/docs/configurations/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@
!!! example "epochs"
The total number of epochs in local training in each communication round.

!!! example "local_steps_per_round"
The DiLoCo local work value `H`, counted as completed client-local optimizer steps between synchronizations.

`H` is not an epoch count, raw dataloader batch count, or gradient-accumulation micro-batch count. When gradient accumulation is enabled, only batches that trigger `optimizer.step()` increment `H`.

`H` may be smaller than one epoch. In that case, local training stops mid-epoch after exactly `H` optimizer steps while still running normal trainer cleanup, callback completion, state persistence, and reporting.

Small-`H` DiLoCo runs use round-aware sampling where supported so a logical client does not replay the same first `H` batches every round. Trainers or samplers that cannot count optimizer steps or advance the local stream faithfully must fail or warn clearly instead of silently approximating DiLoCo.

!!! example "preserve_optimizer_state"
    Whether client-local optimizer and scheduler state should persist across a logical client's local training runs.

DiLoCo should set this to `true` with a stateful inner optimizer such as `AdamW`. Optimizer and scheduler state remains client-local and is not transmitted in client-server payloads.

!!! example "batch_size"
The size of the mini-batch of data in each step (iteration) of the training loop.

Expand Down
Loading
Loading