diff --git a/configs/CIFAR10/diloco_resnet18.toml b/configs/CIFAR10/diloco_resnet18.toml new file mode 100644 index 000000000..ed407000c --- /dev/null +++ b/configs/CIFAR10/diloco_resnet18.toml @@ -0,0 +1,79 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8021 + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.9 + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 +preserve_optimizer_state = true + +# DiLoCo paper inner-optimizer settings. +epochs = 5 +batch_size = 10 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml new file mode 100644 index 000000000..26f32d0ce --- /dev/null +++ b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml @@ -0,0 +1,68 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8022 + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.9 + +# Match the original FedAvg local training shape while keeping 500 optimizer +# steps per round, equal to DiLoCo's H. +epochs = 5 +batch_size = 10 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/MNIST/diloco_lenet5.toml b/configs/MNIST/diloco_lenet5.toml new file mode 100644 index 000000000..53eff9305 --- /dev/null +++ b/configs/MNIST/diloco_lenet5.toml @@ -0,0 +1,75 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8001 +random_seed = 1 +simulate_wall_time = true + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] +include = "mnist_iid.toml" +partition_size = 1000 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.99 + +# The machine learning model +model_name = "lenet5" + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 +preserve_optimizer_state = true + +# DiLoCo paper inner-optimizer settings. +epochs = 5 +batch_size = 32 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml new file mode 100644 index 000000000..e223915bb --- /dev/null +++ b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml @@ -0,0 +1,66 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8002 +random_seed = 1 +simulate_wall_time = true + +[data] +include = "mnist_iid.toml" +partition_size = 1000 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 63 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.99 + +# The machine learning model +model_name = "lenet5" + +# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +# 5 epochs over 1000 samples at batch size 32 gives 160 optimizer steps per +# round. With 63 rounds, FedAvg gets 10,080 local steps, closely matching +# DiLoCo's 20 * H=500 = 10,000-step total budget. +epochs = 5 +batch_size = 32 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/docs/docs/configurations/server.md b/docs/docs/configurations/server.md index 2bb800237..cef578eb9 100644 --- a/docs/docs/configurations/server.md +++ b/docs/docs/configurations/server.md @@ -8,6 +8,7 @@ - `fedavg_personalized` a Federated Averaging server that supports all-purpose personalized federated learning by controlling when and which group of clients are to perform local personalization. - `fedavg_mpc_additive` a Federated Averaging server that reconstructs additive MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_additive` processor. - `fedavg_mpc_shamir` a Federated Averaging server that reconstructs Shamir MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_shamir` processor. + - `diloco` a FedAvg-compatible server that applies DiLoCo outer aggregation. Use it with `algorithm.type = "fedavg"` and configure the outer optimizer under `[server.diloco]`. - `split_learning` a Split Learning server that supports training different kinds of models in split learning framework. When this server is used, the `clients.per_round` in the configuration should be set to 1. Users should define the rules for updating models weights before cut from the clients to the server in the callback function `on_update_weights_before_cut`, depending on the specific model they use. - `fedavg_personalized` a personalized federated learning server that starts from a number of regular rounds of federated learning. In these regular rounds, only a subset of the total clients can be selected to perform the local update (the ratio of which is a configuration setting). After all regular rounds are completed, it starts a final round of personalization, where a selected subset of clients perform local training using their local dataset. - `pfedgraph` a personalized federated learning server that aggregates models using an inferred collaboration graph and sends per-client aggregated weights. @@ -124,6 +125,37 @@ Default value: `100` +!!! example "diloco" + Settings for `server.type = "diloco"`. DiLoCo reuses `algorithm.type = "fedavg"` for client weight extraction and global model loading, while the DiLoCo server turns client deltas into an outer-gradient update. + + ```toml + [server] + type = "diloco" + + [algorithm] + type = "fedavg" + + [server.diloco] + outer_optimizer = "nesterov" + outer_learning_rate = 0.7 + outer_momentum = 0.9 + aggregation_weighting = "uniform" + apply_outer_optimizer_to = "parameters" + ``` + + `aggregation_weighting = "uniform"` matches balanced IID worker smoke runs. `aggregation_weighting = "num_samples"` matches Plato's traditional sample-weighted FedAvg behavior. With outer SGD and `outer_learning_rate = 1.0`, uniform weighting is equivalent to uniform model averaging; with `num_samples`, it is equivalent to Plato-style sample-weighted FedAvg. + + `apply_outer_optimizer_to = "parameters"` applies the outer optimizer only to trainable floating parameters. Floating buffers are synchronized with the selected averaging rule but do not receive outer momentum. `apply_outer_optimizer_to = "all_floating"` is available for experiments that also apply the outer optimizer to floating buffers. + + Runnable comparison configurations are available for MNIST/LeNet and CIFAR-10/ResNet-18: + + ```bash + uv run python plato.py --config configs/MNIST/diloco_lenet5.toml + uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml + ``` + + These configurations validate DiLoCo mechanics in Plato; they are not C4/model/pretraining reproductions of the DiLoCo paper. + !!! example "edge_downlink_bandwidth" The edge server's estimated downlink capacity (an edge server to its clients) in Mbps, used for computing the transmission time (see `compute_comm_time` in the `clients` section). diff --git a/docs/docs/configurations/trainer.md b/docs/docs/configurations/trainer.md index de9eb715e..05fef2f2d 100644 --- a/docs/docs/configurations/trainer.md +++ b/docs/docs/configurations/trainer.md @@ -56,6 +56,20 @@ !!! example "epochs" The total number of epochs in local training in each communication round. +!!! example "local_steps_per_round" + The DiLoCo local work value `H`, counted as completed client-local optimizer steps between synchronizations. + + `H` is not an epoch count, raw dataloader batch count, or gradient-accumulation micro-batch count. When gradient accumulation is enabled, only batches that trigger `optimizer.step()` increment `H`. + + `H` may be smaller than one epoch. In that case, local training stops mid-epoch after exactly `H` optimizer steps while still running normal trainer cleanup, callback completion, state persistence, and reporting. + + Small-`H` DiLoCo runs use round-aware sampling where supported so a logical client does not replay the same first `H` batches every round. Trainers or samplers that cannot count optimizer steps or advance the local stream faithfully must fail or warn clearly instead of silently approximating DiLoCo. + +!!! example "preserve_optimizer_state" + Whether client-local optimizer and scheduler state should persist across a logical client's local train runs. + + DiLoCo should set this to `true` with a stateful inner optimizer such as `AdamW`. Optimizer and scheduler state remains client-local and is not transmitted in client-server payloads. + !!! example "batch_size" The size of the mini-batch of data in each step (iteration) of the training loop. diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md new file mode 100644 index 000000000..f4b8ce405 --- /dev/null +++ b/docs/docs/development/diloco.md @@ -0,0 +1,237 @@ +# DiLoCo Design Contract + +This note defines what Plato calls faithful DiLoCo in the current +implementation. + +Faithful DiLoCo in Plato means algorithm-faithful execution of the DiLoCo +training loop inside Plato's federated runtime. It does not mean reproducing +the paper's exact C4 dataset, model scale, tokenizer, hardware topology, +pretraining duration, or final benchmark numbers. + +## Example Configurations + +Plato includes MNIST/LeNet and CIFAR-10/ResNet-18 comparison configurations +for checking DiLoCo against matched FedAvg runs: + +```bash +uv run python plato.py --config configs/MNIST/diloco_lenet5.toml +uv run python plato.py --config configs/MNIST/fedavg_lenet5_diloco_comparison.toml +uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml +uv run python plato.py --config configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml +``` + +These examples validate Plato's DiLoCo mechanics without reproducing the C4 +dataset, tokenizer, language-model scale, hardware topology, pretraining +duration, or final benchmark numbers from the paper. + +## Algorithm Contract + +DiLoCo has two optimizer levels: + +- The client-local inner optimizer trains each selected logical client for + exactly `H` local optimizer steps between synchronizations. +- The server-side outer optimizer updates the global model from the averaged + outer gradient. + +Plato's FedAvg-style model delta is: + +```text +plato_delta = client_after - global_before +``` + +DiLoCo's outer gradient is: + +```text +outer_gradient = global_before - client_after = -plato_delta +``` + +The DiLoCo server must still return a Plato-compatible model delta because +`algorithm.update_weights()` adds the returned delta to the current global +model. For example, outer SGD with learning rate `1.0` returns the averaged +Plato delta and is equivalent to FedAvg only when the same averaging rule is +used. + +The outer optimizer runs on the server. Clients run only the inner optimizer +and send model weights or weight-equivalent updates. Client-local optimizer and +scheduler state persists per logical client and is never sent to the server. + +## Local Work `H` + +`H` means client-local optimizer steps between synchronizations. It is not: + +- epochs, +- raw dataloader batches, or +- gradient-accumulation micro-batches. + +When gradient accumulation is enabled, `H` counts completed optimizer steps. +Raw batches that do not trigger `optimizer.step()` do not increment `H`. + +`H` may be smaller than one epoch. Faithful DiLoCo must therefore stop local +training mid-epoch after exactly `H` optimizer steps. This early stop must +still run normal trainer cleanup, state persistence, callback completion, and +reporting paths. It must not perform an extra final optimizer step. + +Small-`H` training must not repeatedly replay the same first `H` batches only +because the train loader is recreated each round. The implementation must use +round-aware resampling or an equivalent persistent sampling stream so each +logical client's local data stream advances across rounds in a reproducible +way. + +## State Ownership + +Server-owned state: + +- the global model, +- outer optimizer momentum or other outer optimizer state, +- aggregation metadata needed to update the global model. + +Client-owned state: + +- inner optimizer state, such as AdamW first and second moments, +- scheduler state and global/local optimizer-step counters, +- sampler or dataloader stream position needed for small-`H` continuity. + +Client-owned optimizer and scheduler state must not appear in client-server +payloads. It must remain local to the logical client, including when training +uses subprocesses. + +## Parameter And Buffer Policy + +By default, the outer optimizer applies only to trainable floating parameters. +This matches the algorithm definition, which optimizes model parameters. + +Floating buffers, such as batch normalization running statistics, are +synchronized without outer momentum by default. They use the selected averaging +rule but do not receive server-side momentum or Nesterov treatment. + +Non-floating buffers use conservative FedAvg-style behavior, including casting +or rounding as needed to preserve the buffer's dtype-compatible semantics. + +The implementation may offer `apply_outer_optimizer_to = "all_floating"` for +experiments, but the default must remain `parameters`. + +## Configuration Contract + +The faithful initial mode uses these configuration names and defaults: + +```toml +[server] +type = "diloco" + +[algorithm] +type = "fedavg" + +[trainer] +local_steps_per_round = H +preserve_optimizer_state = true +optimizer = "AdamW" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" # or "num_samples" +apply_outer_optimizer_to = "parameters" # or "all_floating" +``` + +`algorithm.type = "fedavg"` is intentional. Plato should reuse the existing +FedAvg weight extraction, delta computation, and global model loading path, +while `server.type = "diloco"` selects the server-side DiLoCo aggregation and +outer optimizer behavior. + +`aggregation_weighting = "uniform"` matches the balanced worker setting most +closely. `aggregation_weighting = "num_samples"` matches Plato's traditional +sample-weighted FedAvg behavior. FedAvg equivalence for outer SGD with learning +rate `1.0` is valid only when both runs use the same weighting rule. + +Unsupported modes must fail clearly. They must not silently fall back to an +approximate DiLoCo variant. Examples include trainer backends that cannot count +local optimizer steps exactly, execution paths that cannot preserve +client-local optimizer and scheduler state, samplers that cannot advance the +small-`H` local data stream across rounds, or payload paths that would send +optimizer state to the server. Experimental combinations that are allowed but +not faithful must warn clearly. + +## Implementation Sequence + +Dependency graph: + +```text +D1 +|-- D2 --> D3 +|-- D4 --> D5 +|-- D6 --> D7 +|-- D8 --> D9 +`-- D10 --> D11 + +D3, D5, D7, D9, D11 --> D12 --> D13 +``` + +Tasks: + +```yaml +- id: D1 + depends_on: [] + task: Document the exact DiLoCo contract and unsupported modes. + +- id: D2 + depends_on: [D1] + task: Add red tests for server-side outer gradient sign, weighting, and + FedAvg equivalence under matching weighting. + +- id: D3 + depends_on: [D2] + task: Implement DiLoCo server aggregation and outer optimizer state for SGD, + momentum SGD, and Nesterov. + +- id: D4 + depends_on: [D1] + task: Add red tests for exact local optimizer-step counting and `H` smaller + than one epoch. + +- id: D5 + depends_on: [D4] + task: Implement `trainer.local_steps_per_round` with mid-epoch termination + after exactly `H` optimizer steps. + +- id: D6 + depends_on: [D1] + task: Add red tests for per-client optimizer and scheduler state + persistence. + +- id: D7 + depends_on: [D6] + task: Persist client-local optimizer and scheduler state without sending it + to the server. + +- id: D8 + depends_on: [D1] + task: Add red tests for round-aware small-`H` sampling. + +- id: D9 + depends_on: [D8] + task: Implement round-aware resampling or an equivalent persistent sampling + stream for each logical client. + +- id: D10 + depends_on: [D1] + task: Add red tests for parameter and buffer eligibility. + +- id: D11 + depends_on: [D10] + task: Implement the default trainable-parameter-only outer optimizer policy + and conservative buffer synchronization. + +- id: D12 + depends_on: [D3, D5, D7, D9, D11] + task: Wire exact DiLoCo configuration, examples, and user-facing + documentation. + +- id: D13 + depends_on: [D12] + task: Add end-to-end faithful-mode validation coverage. +``` + +Every implementation task should use red/green test-driven development. Add +the failing tests that describe the contract first, then implement the smallest +runtime change that makes those tests pass. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index fa5bc9908..f4177c8b4 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -88,7 +88,9 @@ nav: - Servers: references/servers.md - Trainers: references/trainers.md - Evaluators: references/evaluators.md - - Developer's Guide: development.md + - Developer's Guide: + - Overview: development.md + - DiLoCo Design Contract: development/diloco.md - Deployment Guide: deployment.md - Digital Research Alliance of Canada: ccdb.md - Miscellaneous Notes: misc.md diff --git a/plato/servers/diloco.py b/plato/servers/diloco.py new file mode 100644 index 000000000..bedbb5f05 --- /dev/null +++ b/plato/servers/diloco.py @@ -0,0 +1,50 @@ +"""FedAvg-compatible server using DiLoCo aggregation.""" + +from plato.config import Config +from plato.servers import fedavg +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy + + +class Server(fedavg.Server): + """Federated learning server with server-side DiLoCo outer aggregation.""" + + def __init__( + self, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + aggregation_strategy=None, + client_selection_strategy=None, + ): + if aggregation_strategy is None: + aggregation_strategy = DiLoCoAggregationStrategy( + **self._aggregation_config() + ) + + super().__init__( + model=model, + datasource=datasource, + algorithm=algorithm, + trainer=trainer, + callbacks=callbacks, + aggregation_strategy=aggregation_strategy, + client_selection_strategy=client_selection_strategy, + ) + + @staticmethod + def _aggregation_config() -> dict: + """Read optional DiLoCo aggregation settings from [server.diloco].""" + config = getattr(Config().server, "diloco", None) + if config is None: + return {} + + keys = ( + "outer_optimizer", + "outer_learning_rate", + "outer_momentum", + "aggregation_weighting", + "apply_outer_optimizer_to", + ) + return {key: getattr(config, key) for key in keys if hasattr(config, key)} diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 5f862fc35..e357adc31 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -222,13 +222,20 @@ async def _process_reports(self): # Use delta aggregation (default path) # Computes the weight deltas by comparing the weights received with # the current global model weights - deltas_received = algorithm.compute_weight_deltas( - baseline_weights, weights_received + delta_updates, delta_weights_received = ( + self._weight_updates_and_payloads(self.updates, weights_received) + ) + deltas_received = ( + algorithm.compute_weight_deltas( + baseline_weights, delta_weights_received + ) + if delta_weights_received + else [] ) # Runs a framework-agnostic server aggregation algorithm, such as # the federated averaging algorithm logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid()) - deltas = await self.aggregate_deltas(self.updates, deltas_received) + deltas = await self.aggregate_deltas(delta_updates, deltas_received) # Updates the existing model weights from the provided deltas updated_weights = algorithm.update_weights(deltas) # Loads the new model weights @@ -299,6 +306,20 @@ def _should_prefer_weight_aggregation(self) -> bool: and aggregate_deltas_impl is not FedAvgAggregationStrategy.aggregate_deltas ) + @staticmethod + def _weight_updates_and_payloads(updates, weights_received): + """Return update/payload pairs whose reports contain model weights.""" + delta_updates = [] + delta_weights_received = [] + + for update, weights in zip(updates, weights_received): + if getattr(update.report, "type", "weights") != "weights": + continue + delta_updates.append(update) + delta_weights_received.append(weights) + + return delta_updates, delta_weights_received + def clients_processed(self) -> None: """Additional work to be performed after client reports have been processed.""" diff --git a/plato/servers/registry.py b/plato/servers/registry.py index 2211b8972..b26465a8c 100644 --- a/plato/servers/registry.py +++ b/plato/servers/registry.py @@ -10,6 +10,7 @@ from plato.config import Config from plato.servers import ( + diloco, fedavg, fedavg_cs, fedavg_gan, @@ -30,6 +31,7 @@ registered_servers = { "fedavg": fedavg.Server, "fedavg_lora": fedavg.Server, + "diloco": diloco.Server, "fedavg_cross_silo": fedavg_cs.Server, "fedavg_gan": fedavg_gan.Server, "fedavg_personalized": fedavg_personalized.Server, diff --git a/plato/servers/strategies/aggregation/__init__.py b/plato/servers/strategies/aggregation/__init__.py index f44c420a4..fe4174402 100644 --- a/plato/servers/strategies/aggregation/__init__.py +++ b/plato/servers/strategies/aggregation/__init__.py @@ -4,6 +4,7 @@ Each strategy is defined in its own module for clarity. """ +from plato.servers.strategies.aggregation.diloco import DiLoCoAggregationStrategy from plato.servers.strategies.aggregation.fedasync import FedAsyncAggregationStrategy from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy from plato.servers.strategies.aggregation.fedbuff import FedBuffAggregationStrategy @@ -16,6 +17,7 @@ __all__ = [ "FedAvgAggregationStrategy", + "DiLoCoAggregationStrategy", "FedBuffAggregationStrategy", "FedNovaAggregationStrategy", "FedAsyncAggregationStrategy", diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py new file mode 100644 index 000000000..ddc2428fc --- /dev/null +++ b/plato/servers/strategies/aggregation/diloco.py @@ -0,0 +1,501 @@ +""" +DiLoCo aggregation strategy. + +The strategy consumes Plato-style client deltas (`client_after - global_before`), +converts them to DiLoCo outer gradients, and returns Plato-compatible server +deltas for `algorithm.update_weights()` to add to the global model. +""" + +from __future__ import annotations + +import asyncio +import copy +import logging +import numbers +from collections.abc import Callable, Mapping +from types import SimpleNamespace +from typing import Any, cast + +import numpy as np + +from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy +from plato.servers.strategies.base import ServerContext + +try: # pragma: no cover - optional dependency + import torch +except ImportError: # pragma: no cover + torch = cast(Any, None) + + +class DiLoCoAggregationStrategy(FedAvgAggregationStrategy): + """Aggregate client deltas with a server-side DiLoCo outer optimizer.""" + + _SUPPORTED_OPTIMIZERS = {"sgd", "sgdm", "nesterov"} + _SUPPORTED_WEIGHTING_MODES = {"uniform", "num_samples"} + _SUPPORTED_APPLY_POLICIES = {"parameters", "all_floating"} + + def __init__( + self, + outer_optimizer: str = "nesterov", + outer_learning_rate: float = 0.7, + outer_momentum: float = 0.9, + aggregation_weighting: str = "uniform", + apply_outer_optimizer_to: str = "parameters", + ): + super().__init__() + self.outer_optimizer = self._validate_outer_optimizer(outer_optimizer) + self.outer_learning_rate = self._validate_learning_rate( + outer_learning_rate + ) + self.outer_momentum = self._validate_momentum(outer_momentum) + self.aggregation_weighting = self._validate_weighting_mode( + aggregation_weighting + ) + self.apply_outer_optimizer_to = self._validate_apply_policy( + apply_outer_optimizer_to + ) + self.momentum_state: dict[str, Any] = {} + + async def aggregate_deltas( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + context: ServerContext, + ) -> dict: + """Aggregate deltas and apply the configured DiLoCo outer optimizer.""" + eligible = self._eligible_updates(updates, deltas_received) + if not eligible: + self._remove_stale_momentum(set()) + return self._empty_delta(context, self._first_delta(deltas_received)) + + weights = self._aggregation_weights(eligible) + if not weights: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta: Any = None + for (_, delta, _), weight in zip(eligible, weights): + avg_delta = self._accumulate_weighted(avg_delta, delta, weight, context) + await asyncio.sleep(0) + + if avg_delta is None: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta = self._match_reference_structure(avg_delta, eligible[0][1]) + optimizer_paths = self._outer_optimizer_paths(avg_delta, context) + server_delta, active_paths = self._apply_outer_optimizer( + avg_delta, optimizer_paths + ) + logging.info( + "[Server] DiLoCo outer optimizer applied: optimizer=%s " + "outer_lr=%g outer_momentum=%g weighting=%s apply_to=%s " + "eligible_updates=%d optimized_tensors=%d.", + self.outer_optimizer, + self.outer_learning_rate, + self.outer_momentum, + self.aggregation_weighting, + self.apply_outer_optimizer_to, + len(eligible), + len(optimizer_paths), + ) + self._remove_stale_momentum(active_paths) + + return self._match_reference_structure(server_delta, eligible[0][1]) + + @classmethod + def _validate_outer_optimizer(cls, value: str) -> str: + optimizer = str(value).lower() + if optimizer not in cls._SUPPORTED_OPTIMIZERS: + supported = ", ".join(sorted(cls._SUPPORTED_OPTIMIZERS)) + raise ValueError( + f"Invalid outer_optimizer '{value}'. Supported values: {supported}." + ) + return optimizer + + @staticmethod + def _validate_learning_rate(value: float) -> float: + learning_rate = float(value) + if learning_rate < 0: + raise ValueError("outer_learning_rate must be nonnegative.") + return learning_rate + + @staticmethod + def _validate_momentum(value: float) -> float: + momentum = float(value) + if not 0 <= momentum < 1: + raise ValueError("outer_momentum must be in the range [0, 1).") + return momentum + + @classmethod + def _validate_weighting_mode(cls, value: str) -> str: + weighting = str(value).lower() + if weighting not in cls._SUPPORTED_WEIGHTING_MODES: + supported = ", ".join(sorted(cls._SUPPORTED_WEIGHTING_MODES)) + raise ValueError( + "Invalid aggregation_weighting " + f"'{value}'. Supported values: {supported}." + ) + return weighting + + @classmethod + def _validate_apply_policy(cls, value: str) -> str: + policy = str(value).lower() + if policy not in cls._SUPPORTED_APPLY_POLICIES: + supported = ", ".join(sorted(cls._SUPPORTED_APPLY_POLICIES)) + raise ValueError( + "Invalid apply_outer_optimizer_to " + f"'{value}'. Supported values: {supported}." + ) + return policy + + def _eligible_updates( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + ) -> list[tuple[SimpleNamespace, dict, float]]: + eligible: list[tuple[SimpleNamespace, dict, float]] = [] + for update, delta in zip(updates, deltas_received): + if getattr(update.report, "type", "weights") == "features": + continue + + num_samples = self._num_samples(update) + if num_samples <= 0: + continue + + eligible.append((update, delta, num_samples)) + + return eligible + + @staticmethod + def _num_samples(update: SimpleNamespace) -> float: + try: + return float(update.report.num_samples) + except (AttributeError, TypeError, ValueError): + return 0.0 + + def _aggregation_weights( + self, eligible: list[tuple[SimpleNamespace, dict, float]] + ) -> list[float]: + if not eligible: + return [] + + if self.aggregation_weighting == "uniform": + return [1.0 / len(eligible)] * len(eligible) + + total_samples = sum(num_samples for _, _, num_samples in eligible) + if total_samples <= 0: + return [] + + return [num_samples / total_samples for _, _, num_samples in eligible] + + def _outer_optimizer_paths( + self, avg_delta: Any, context: ServerContext + ) -> set[str]: + if self.apply_outer_optimizer_to == "all_floating": + return self._floating_leaf_paths(avg_delta) + + floating_paths = self._floating_leaf_paths(avg_delta) + trainable_parameter_names = self._trainable_parameter_names( + context, floating_paths + ) + return floating_paths.intersection(trainable_parameter_names) + + def _apply_outer_optimizer( + self, avg_delta: Any, optimizer_paths: set[str] + ) -> tuple[Any, set[str]]: + active_paths: set[str] = set() + + server_delta = self._map_tree( + avg_delta, + lambda value, path: self._apply_outer_optimizer_leaf( + value, path, optimizer_paths, active_paths + ), + ) + return server_delta, active_paths + + def _apply_outer_optimizer_leaf( + self, + avg_delta: Any, + path: str, + optimizer_paths: set[str], + active_paths: set[str], + ) -> Any: + if path not in optimizer_paths: + return avg_delta + + outer_gradient = self._scale_tree(avg_delta, -1.0) + if self.outer_optimizer == "sgd": + return self._scale_tree(outer_gradient, -self.outer_learning_rate) + + return self._apply_momentum_leaf(outer_gradient, path, active_paths) + + def _apply_momentum_leaf( + self, outer_gradient: Any, path: str, active_paths: set[str] + ) -> Any: + active_paths.add(path) + previous = self.momentum_state.get(path) + if previous is not None and not self._is_compatible(previous, outer_gradient): + previous = None + + if previous is None: + momentum = self._clone_tree(outer_gradient) + else: + momentum = self._add_values( + self._scale_tree(previous, self.outer_momentum), + outer_gradient, + ) + + self.momentum_state[path] = self._clone_tree(momentum) + + if self.outer_optimizer == "nesterov": + direction = self._add_values( + outer_gradient, + self._scale_tree(momentum, self.outer_momentum), + ) + else: + direction = momentum + + return self._scale_tree(direction, -self.outer_learning_rate) + + def _remove_stale_momentum(self, active_paths: set[str]) -> None: + if self.outer_optimizer == "sgd": + self.momentum_state.clear() + return + + for path in list(self.momentum_state): + if path not in active_paths: + del self.momentum_state[path] + + def _trainable_parameter_names( + self, context: ServerContext, payload_paths: set[str] | None = None + ) -> set[str]: + model = self._model_from_context(context) + adapter_names = self._adapter_names(model) + trainable_names: set[str] = set() + + for name, parameter in model.named_parameters(): + if getattr(parameter, "requires_grad", False) and self._is_floating_value( + parameter + ): + trainable_names.update( + self._payload_name_candidates(name, adapter_names, payload_paths) + ) + + return trainable_names + + @staticmethod + def _adapter_names(model: Any) -> set[str]: + adapter_names = {"default"} + + peft_config = getattr(model, "peft_config", None) + if isinstance(peft_config, Mapping): + adapter_names.update(str(name) for name in peft_config) + + active_adapter = getattr(model, "active_adapter", None) + if isinstance(active_adapter, str): + adapter_names.add(active_adapter) + + active_adapters = getattr(model, "active_adapters", None) + if callable(active_adapters): + try: + adapter_names.update(str(name) for name in active_adapters()) + except TypeError: + pass + elif isinstance(active_adapters, (list, tuple, set)): + adapter_names.update(str(name) for name in active_adapters) + + return adapter_names + + @classmethod + def _payload_name_candidates( + cls, + parameter_name: str, + adapter_names: set[str], + payload_paths: set[str] | None, + ) -> set[str]: + candidates = {parameter_name} + if payload_paths is not None and parameter_name in payload_paths: + return candidates + + parts = parameter_name.split(".") + for index, part in enumerate(parts): + if part not in adapter_names: + continue + + candidate = ".".join(parts[:index] + parts[index + 1 :]) + if payload_paths is None or candidate in payload_paths: + candidates.add(candidate) + + return candidates + + @staticmethod + def _model_from_context(context: ServerContext) -> Any: + trainer = getattr(context, "trainer", None) + model = getattr(trainer, "model", None) if trainer is not None else None + if model is None or not hasattr(model, "named_parameters"): + raise AttributeError( + "DiLoCo apply_outer_optimizer_to='parameters' requires " + "context.trainer.model with named_parameters()." + ) + return model + + def _floating_leaf_paths(self, value: Any) -> set[str]: + return self._collect_leaf_paths( + value, lambda leaf, _: self._is_floating_value(leaf) + ) + + def _collect_leaf_paths( + self, + value: Any, + predicate: Callable[[Any, str], bool], + path: str = "", + ) -> set[str]: + if isinstance(value, Mapping): + paths: set[str] = set() + for key, item in value.items(): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, key) + ) + ) + return paths + + if isinstance(value, list): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + if isinstance(value, tuple): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + return {path} if predicate(value, path) else set() + + @staticmethod + def _is_floating_value(value: Any) -> bool: + if torch is not None and isinstance(value, torch.Tensor): + return torch.is_floating_point(value) + + if isinstance(value, np.ndarray): + return np.issubdtype(value.dtype, np.floating) + + return isinstance(value, numbers.Real) and not isinstance( + value, (numbers.Integral, bool) + ) + + def _empty_delta(self, context: ServerContext, reference_delta: Any | None) -> dict: + zero_delta = self._zero_delta(context, reference_delta) + if zero_delta is not None: + return zero_delta + + if reference_delta is None: + return {} + + return self._scale_tree(reference_delta, 0.0) + + @staticmethod + def _first_delta(deltas_received: list[dict]) -> dict | None: + return deltas_received[0] if deltas_received else None + + def _map_tree(self, value: Any, leaf_fn: Callable[[Any, str], Any], path="") -> Any: + if isinstance(value, Mapping): + return { + key: self._map_tree(item, leaf_fn, self._join_path(path, key)) + for key, item in value.items() + } + + if isinstance(value, list): + return [ + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ] + + if isinstance(value, tuple): + return tuple( + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ) + + return leaf_fn(value, path) + + def _scale_tree(self, value: Any, scalar: float) -> Any: + if isinstance(value, Mapping): + return { + key: self._scale_tree(item, scalar) for key, item in value.items() + } + + if isinstance(value, list): + return [self._scale_tree(item, scalar) for item in value] + + if isinstance(value, tuple): + return tuple(self._scale_tree(item, scalar) for item in value) + + return value * scalar + + @staticmethod + def _add_values(left: Any, right: Any) -> Any: + return left + right + + def _clone_tree(self, value: Any) -> Any: + if isinstance(value, Mapping): + return {key: self._clone_tree(item) for key, item in value.items()} + + if isinstance(value, list): + return [self._clone_tree(item) for item in value] + + if isinstance(value, tuple): + return tuple(self._clone_tree(item) for item in value) + + if torch is not None and isinstance(value, torch.Tensor): + return value.detach().clone() + + if isinstance(value, np.ndarray): + return value.copy() + + try: + return copy.deepcopy(value) + except TypeError: + return value + + @staticmethod + def _is_compatible(left: Any, right: Any) -> bool: + if torch is not None and isinstance(left, torch.Tensor): + return ( + isinstance(right, torch.Tensor) + and left.shape == right.shape + and left.dtype == right.dtype + ) + + if isinstance(left, np.ndarray): + return ( + isinstance(right, np.ndarray) + and left.shape == right.shape + and left.dtype == right.dtype + ) + + left_shape = getattr(left, "shape", None) + right_shape = getattr(right, "shape", None) + if left_shape is not None or right_shape is not None: + return ( + left_shape == right_shape + and getattr(left, "dtype", None) == getattr(right, "dtype", None) + ) + + return isinstance(left, numbers.Number) and isinstance(right, numbers.Number) + + @staticmethod + def _join_path(prefix: str, key: Any) -> str: + key_text = str(key) + return key_text if not prefix else f"{prefix}.{key_text}" diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 1e98d1128..fc6feb886 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -168,6 +168,7 @@ def __init__( self.current_epoch = 0 self.training_start_time = time.time() self.model_state_dict = None + self._preserved_optimizer_states: dict[int, dict[str, Any]] = {} def _require_model(self) -> nn.Module: """Return the underlying model, ensuring it is available.""" @@ -177,6 +178,300 @@ def _require_model(self) -> nn.Module: ) return cast(nn.Module, self.model) + @staticmethod + def _local_steps_per_round(config: dict[str, Any]) -> int | None: + """Return the optional local optimizer-step limit for one train run.""" + value = config.get("local_steps_per_round") + if value is None: + return None + + if isinstance(value, bool) or not isinstance(value, int) or value <= 0: + raise ValueError( + "trainer.local_steps_per_round must be a positive integer." + ) + + return value + + def _record_local_optimizer_step(self, local_steps_per_round: int | None) -> bool: + """Record one completed optimizer step and report whether H was reached.""" + if local_steps_per_round is None: + return False + + completed_steps = int(self.context.state.get("local_optimizer_steps", 0)) + 1 + self.context.state["local_optimizer_steps"] = completed_steps + return completed_steps >= local_steps_per_round + + @staticmethod + def _preserve_optimizer_state(config: dict[str, Any]) -> bool: + """Return whether optimizer state should survive local train runs.""" + return bool(config.get("preserve_optimizer_state", False)) + + @staticmethod + def _step_lr_scheduler_per_optimizer_step(config: dict[str, Any]) -> bool: + """Return whether LR scheduling should follow optimizer steps.""" + if config.get("local_steps_per_round") is None: + return False + + return getattr(Config().server, "type", None) == "diloco" + + def _step_lr_scheduler_after_optimizer_step( + self, step_lr_per_optimizer_step: bool + ) -> None: + """Advance step-based LR schedules after one completed optimizer step.""" + if step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + + @staticmethod + def _parameter_signature(name: str | None, parameter: torch.Tensor): + """Build a compatibility signature for one model parameter.""" + return (name, tuple(parameter.shape), str(parameter.dtype)) + + @classmethod + def _model_parameter_signature(cls, model: nn.Module): + """Return parameter names, shapes, dtypes, and order for a model.""" + return tuple( + cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + ) + + @classmethod + def _optimizer_parameter_signature( + cls, model: nn.Module, optimizer: torch.optim.Optimizer + ): + """Return optimizer parameter group ordering with model metadata.""" + named_parameters = { + id(parameter): cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + } + + group_signatures = [] + for group in optimizer.param_groups: + group_signatures.append( + tuple( + named_parameters.get( + id(parameter), + cls._parameter_signature(None, parameter), + ) + for parameter in group.get("params", []) + ) + ) + + return tuple(group_signatures) + + @staticmethod + def _scheduler_type(scheduler: Any | None) -> type | None: + """Return the scheduler type used for compatibility checks.""" + if scheduler is None: + return None + return type(scheduler) + + def _preserved_state_is_compatible( + self, + payload: dict[str, Any], + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: Any | None, + ) -> bool: + """Return whether a cached optimizer bundle matches this train run.""" + if payload.get("optimizer_type") is not type(optimizer): + return False + + if payload.get("scheduler_type") is not self._scheduler_type(scheduler): + return False + + if payload.get("model_parameters") != self._model_parameter_signature(model): + return False + + if payload.get("optimizer_parameters") != self._optimizer_parameter_signature( + model, optimizer + ): + return False + + if not callable(getattr(optimizer, "load_state_dict", None)): + return False + + if payload.get("scheduler_state") is not None and not callable( + getattr(scheduler, "load_state_dict", None) + ): + return False + + return True + + def _restore_preserved_optimizer_state(self) -> None: + """Restore compatible optimizer and scheduler state for this client.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None or self.optimizer is None: + return + + model = self._require_model() + if not self._preserved_state_is_compatible( + payload, model, self.optimizer, self.lr_scheduler + ): + logging.info( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + ) + self._preserved_optimizer_states.pop(self.client_id, None) + return + + try: + scheduler_state = payload.get("scheduler_state") + if scheduler_state is not None: + self.lr_scheduler.load_state_dict(copy.deepcopy(scheduler_state)) + + self.optimizer.load_state_dict(copy.deepcopy(payload["optimizer_state"])) + except Exception as error: + logging.warning( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state: %s", + self.client_id, + error, + ) + self._preserved_optimizer_states.pop(self.client_id, None) + self.optimizer = self.optimizer_strategy.create_optimizer( + model, self.context + ) + self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( + self.optimizer, self.context + ) + + def _save_preserved_optimizer_state(self) -> None: + """Save optimizer and scheduler state locally for this logical client.""" + if self.optimizer is None: + return + + model = self._require_model() + scheduler_state = None + if self.lr_scheduler is not None: + state_dict_fn = getattr(self.lr_scheduler, "state_dict", None) + if callable(state_dict_fn): + scheduler_state = copy.deepcopy(state_dict_fn()) + + self._preserved_optimizer_states[self.client_id] = { + "optimizer_type": type(self.optimizer), + "optimizer_state": copy.deepcopy(self.optimizer.state_dict()), + "scheduler_type": self._scheduler_type(self.lr_scheduler), + "scheduler_state": scheduler_state, + "model_parameters": self._model_parameter_signature(model), + "optimizer_parameters": self._optimizer_parameter_signature( + model, self.optimizer + ), + } + + def _optimizer_state_filename(self, run_id: str) -> str: + """Return the local optimizer-state handoff filename.""" + model_name = Config().trainer.model_name + return f"{model_name}_{self.client_id}_{run_id}.optim.pkl" + + def _optimizer_state_output_filename(self, run_id: str) -> str: + """Return a unique subprocess optimizer-state output filename.""" + model_name = Config().trainer.model_name + token = time.time_ns() + return f"{model_name}_{self.client_id}_{run_id}_{os.getpid()}_{token}.optim.pkl" + + def _optimizer_state_path(self, filename: str) -> str: + """Return the local optimizer-state handoff path.""" + return os.path.join(Config().params["model_path"], filename) + + def _save_preserved_optimizer_state_file(self, filename: str) -> bool: + """Persist preserved optimizer state for subprocess handoff.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None: + return False + + model_path = Config().params["model_path"] + os.makedirs(model_path, exist_ok=True) + state_path = self._optimizer_state_path(filename) + tmp_path = f"{state_path}.{os.getpid()}.tmp" + + try: + with open(tmp_path, "wb") as state_file: + pickle.dump(copy.deepcopy(payload), state_file) + os.replace(tmp_path, state_path) + return True + except Exception as error: + if os.path.exists(tmp_path): + os.remove(tmp_path) + logging.warning( + "[Client #%d] Failed to persist optimizer state to %s: %s", + self.client_id, + state_path, + error, + ) + return False + + def _load_preserved_optimizer_state_file( + self, filename: str, *, clear_on_missing: bool = False + ) -> bool: + """Load preserved optimizer state from a subprocess handoff file.""" + state_path = self._optimizer_state_path(filename) + if not os.path.exists(state_path): + if clear_on_missing: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.info( + "[Client #%d] No persisted optimizer state found at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return False + + try: + with open(state_path, "rb") as state_file: + payload = pickle.load(state_file) + except Exception as error: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding unreadable optimizer state at %s; " + "starting with fresh optimizer and scheduler state: %s", + self.client_id, + state_path, + error, + ) + return False + + if not isinstance(payload, dict): + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding invalid optimizer state at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return False + + self._preserved_optimizer_states[self.client_id] = payload + return True + + def _remove_preserved_optimizer_state_file(self, filename: str) -> None: + """Remove a local optimizer-state sidecar if it exists.""" + state_path = self._optimizer_state_path(filename) + try: + os.remove(state_path) + except FileNotFoundError: + return + except OSError as error: + logging.warning( + "[Client #%d] Failed to remove optimizer state at %s: %s", + self.client_id, + state_path, + error, + ) + + def _finish_subprocess_optimizer_state( + self, input_filename: str, output_filename: str + ) -> None: + """Load the child output sidecar and promote it for the next round.""" + loaded = self._load_preserved_optimizer_state_file( + output_filename, clear_on_missing=True + ) + if loaded: + self._save_preserved_optimizer_state_file(input_filename) + self._remove_preserved_optimizer_state_file(output_filename) + else: + self._remove_preserved_optimizer_state_file(input_filename) + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -384,8 +679,25 @@ def simulate_sleep_time(self): def train_process(self, config, trainset, sampler, **kwargs): """The training process in a federated learning workload.""" + preserve_optimizer_state = self._preserve_optimizer_state(config) + if preserve_optimizer_state: + optimizer_state_filename = config.get( + "_optimizer_state_input_filename", + self._optimizer_state_filename(config["run_id"]), + ) + optimizer_state_output_filename = config.get( + "_optimizer_state_output_filename", + optimizer_state_filename, + ) + self._load_preserved_optimizer_state_file( + optimizer_state_filename, clear_on_missing=True + ) + self.train_model(config, trainset, sampler, **kwargs) + if preserve_optimizer_state: + self._save_preserved_optimizer_state_file(optimizer_state_output_filename) + model_name = Config().trainer.model_name filename = f"{model_name}_{self.client_id}_{config['run_id']}.safetensors" self.save_model(filename) @@ -397,6 +709,16 @@ def train_model(self, config, trainset, sampler, **kwargs): self.sampler = sampler self.context.config = config self.context.current_round = self.current_round + preserve_optimizer_state = self._preserve_optimizer_state(config) + if not preserve_optimizer_state: + self._preserved_optimizer_states.pop(self.client_id, None) + + local_steps_per_round = self._local_steps_per_round(config) + self.context.state["local_optimizer_steps"] = 0 + if local_steps_per_round is None: + self.context.state.pop("local_steps_per_round", None) + else: + self.context.state["local_steps_per_round"] = local_steps_per_round # Ensure training step strategy respects higher-order gradient settings if self.training_step_strategy is not None: @@ -476,24 +798,28 @@ def train_model(self, config, trainset, sampler, **kwargs): self.context.state["grad_accum_loss_total"] = 0.0 self.context.state["grad_accum_loss_count"] = 0 - # Create optimizer using strategy + # Move the model before optimizer state restore so PyTorch maps restored + # state tensors onto the same device as the optimizer parameters. model = self._require_model() + model.to(self.device) + model.train() + + # Create optimizer using strategy self.optimizer = self.optimizer_strategy.create_optimizer(model, self.context) # Create LR scheduler using strategy self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( self.optimizer, self.context ) - - # Move model to device - model = self._require_model() - model.to(self.device) - model.train() + if preserve_optimizer_state: + self._restore_preserved_optimizer_state() # Training epochs total_epochs = config["epochs"] + step_lr_per_optimizer_step = self._step_lr_scheduler_per_optimizer_step(config) tic = time.perf_counter() training_stop_requested = False + local_step_limit_reached = False try: total_batches = len(self.train_loader) except (TypeError, AttributeError): @@ -564,6 +890,12 @@ def compute_loss(outputs, labels_inner): self.optimizer_strategy.on_optimizer_step( self.optimizer, self.context ) + self._step_lr_scheduler_after_optimizer_step( + step_lr_per_optimizer_step + ) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) # Strategy hook: after_step self.model_update_strategy.after_step(self.context) @@ -591,7 +923,7 @@ def compute_loss(outputs, labels_inner): ): self._handle_control_log() - if control_actions.get("stop_training"): + if control_actions.get("stop_training") or local_step_limit_reached: training_stop_requested = True break @@ -601,7 +933,11 @@ def compute_loss(outputs, labels_inner): finalize_loss = None finalize_step_done = False finalize_callable = getattr(self.training_step_strategy, "finalize", None) - if batches_seen and callable(finalize_callable): + if ( + batches_seen + and callable(finalize_callable) + and not local_step_limit_reached + ): finalize_loss = finalize_callable( model=model, optimizer=self.optimizer, @@ -613,6 +949,12 @@ def compute_loss(outputs, labels_inner): ) if finalize_step_done: self.optimizer_strategy.on_optimizer_step(self.optimizer, self.context) + self._step_lr_scheduler_after_optimizer_step( + step_lr_per_optimizer_step + ) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) self.model_update_strategy.after_step(self.context) self.callback_handler.call_event( "on_train_step_end", @@ -652,11 +994,15 @@ def compute_loss(outputs, labels_inner): # No batches remain, but respect control flag. pass + if local_step_limit_reached: + training_stop_requested = True + self.context.state.pop("is_last_batch", None) self.context.state.pop("hf_optimizer_step_index", None) # LR scheduler step - self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + if not step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) # Handle optimizer params state update if needed if hasattr(self.optimizer, "params_state_update"): @@ -701,6 +1047,9 @@ def compute_loss(outputs, labels_inner): # Callbacks: train run end self.callback_handler.call_event("on_train_run_end", self, config) + if preserve_optimizer_state: + self._save_preserved_optimizer_state() + def train(self, trainset, sampler, **kwargs) -> float: """ The main training loop in a federated learning workload. @@ -721,6 +1070,21 @@ def train(self, trainset, sampler, **kwargs) -> float: if "max_concurrency" in config: tic = time.perf_counter() + preserve_optimizer_state = self._preserve_optimizer_state(config) + optimizer_state_filename = None + optimizer_state_output_filename = None + if preserve_optimizer_state: + optimizer_state_filename = self._optimizer_state_filename( + config["run_id"] + ) + optimizer_state_output_filename = self._optimizer_state_output_filename( + config["run_id"] + ) + config = { + **config, + "_optimizer_state_input_filename": optimizer_state_filename, + "_optimizer_state_output_filename": optimizer_state_output_filename, + } if mp.get_start_method(allow_none=True) != "spawn": mp.set_start_method("spawn", force=True) @@ -773,6 +1137,14 @@ def train(self, trainset, sampler, **kwargs) -> float: f"Training on client {self.client_id} failed." ) from error + if ( + optimizer_state_filename is not None + and optimizer_state_output_filename is not None + ): + self._finish_subprocess_optimizer_state( + optimizer_state_filename, optimizer_state_output_filename + ) + toc = time.perf_counter() self.pause_training() else: diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index 91e5a0482..c934e48b4 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -14,12 +14,26 @@ import torch import torch.utils.data +from plato.config import Config from plato.trainers.strategies.base import DataLoaderStrategy, TrainingContext CollateFn = Callable[[list[Any]], Any] AdjustFn = Callable[[TrainingContext], int] +class _FixedOrderSampler(torch.utils.data.Sampler): + """Sampler that yields precomputed dataset indices in order.""" + + def __init__(self, indices: list[int]): + self._indices = indices + + def __iter__(self): + return iter(self._indices) + + def __len__(self): + return len(self._indices) + + def _context_uses_cuda(context: TrainingContext) -> bool: """Return True if the training context targets a CUDA device.""" device = getattr(context, "device", None) @@ -40,6 +54,77 @@ def _resolve_pin_memory(setting: bool | None, context: TrainingContext) -> bool: return _context_uses_cuda(context) +def _local_step_stream_start( + context: TrainingContext, samples_per_round: int, stream_length: int +) -> int: + """Return the deterministic stream offset for this local-step round.""" + current_round = int(getattr(context, "current_round", 0) or 0) + if current_round > 0: + return ((current_round - 1) * samples_per_round) % stream_length + + offset = int(context.state.get("_local_step_sampler_stream_offset", 0)) + context.state["_local_step_sampler_stream_offset"] = offset + samples_per_round + return offset % stream_length + + +def _enforce_diloco_full_participation_for_local_steps() -> None: + """Require DiLoCo workers to train once per outer synchronization.""" + server_type = getattr(Config().server, "type", None) + if server_type != "diloco": + return + + total_clients = int(Config().clients.total_clients) + clients_per_round = int(Config().clients.per_round) + if clients_per_round == total_clients: + return + + raise ValueError( + "DiLoCo local-step data loading requires clients.per_round to equal " + "clients.total_clients so every worker advances its local data stream " + "once per outer round." + ) + + +def _apply_local_step_sampling_stream( + sampler_obj, batch_size: int, context: TrainingContext +): + """Advance deterministic samplers across short local-step rounds.""" + local_steps_per_round = context.state.get("local_steps_per_round") + if local_steps_per_round is None: + return sampler_obj + + _enforce_diloco_full_participation_for_local_steps() + + if sampler_obj is None: + return sampler_obj + + samples_per_round = int(local_steps_per_round) * int(batch_size) + if samples_per_round <= 0: + return sampler_obj + + try: + indices = list(iter(sampler_obj)) + except (TypeError, NotImplementedError): + logging.warning( + "Sampler %s cannot be materialized for round-aware local-step " + "sampling; using it unchanged. Consecutive short local rounds may " + "replay the same sampler prefix.", + type(sampler_obj), + ) + return sampler_obj + + if len(indices) == 0: + return sampler_obj + + start = _local_step_stream_start(context, samples_per_round, len(indices)) + if start == 0: + ordered_indices = indices + else: + ordered_indices = indices[start:] + indices[:start] + + return _FixedOrderSampler(ordered_indices) + + class DefaultDataLoaderStrategy(DataLoaderStrategy): """ Default data loader strategy. @@ -100,6 +185,10 @@ def create_train_loader( sampler_obj = None shuffle = self.shuffle + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + if sampler is None and not shuffle: logging.warning( "Data loader strategy received no sampler; falling back to SequentialSampler." @@ -174,6 +263,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -239,6 +332,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -320,6 +417,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, actual_batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=actual_batch_size, @@ -383,6 +484,10 @@ def create_train_loader( sampler_obj = None shuffle = True + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, diff --git a/plato/trainers/strategies/training_step.py b/plato/trainers/strategies/training_step.py index b4aba6d9d..5afa9702c 100644 --- a/plato/trainers/strategies/training_step.py +++ b/plato/trainers/strategies/training_step.py @@ -128,6 +128,9 @@ def training_step( if self.current_step % self.accumulation_steps == 0: optimizer.step() optimizer.zero_grad() + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False # Return unscaled loss for logging return loss diff --git a/tests/clients/test_simple_client.py b/tests/clients/test_simple_client.py index 43ffbb907..11b1bd8e0 100644 --- a/tests/clients/test_simple_client.py +++ b/tests/clients/test_simple_client.py @@ -1,7 +1,10 @@ """End-to-end smoke tests for the strategy-based client runtime.""" import asyncio +import pickle +import sys from dataclasses import dataclass +from pathlib import Path import torch from torch.utils.data import Dataset @@ -10,6 +13,20 @@ from plato.clients import simple from plato.config import Config from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import AdamWOptimizerStrategy, StepLRSchedulerStrategy +from tests.test_utils.fakes import NoOpCommunicationStrategy + +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} class ToyDataset(Dataset): @@ -48,37 +65,155 @@ def get_test_set(self): return self._test -def _build_client(): +def _build_client(trainer=ComposableTrainer): """Instantiate a client wired with custom model, datasource, and trainer.""" return simple.Client( model=torch.nn.Linear(4, 2), datasource=ToyDatasource, - trainer=ComposableTrainer, + trainer=trainer, algorithm=lambda trainer: fedavg.Algorithm(trainer), ) -def test_simple_client_trains_with_default_strategies(temp_config): - """A simple client should complete one training round using the strategy stack.""" - Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) +def _build_stateful_trainer(model=None, callbacks=None): + """Build a trainer whose local optimizer and scheduler state is non-empty.""" + return ComposableTrainer( + model=model, + callbacks=callbacks, + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) - client = _build_client() - # Assign identifiers expected by the client runtime. +def _configure_one_round_client(client): + """Prepare a client for a deterministic single training round.""" client.client_id = 1 client._context.client_id = 1 client.current_round = 1 client._context.current_round = 1 - # Prepare data and runtime components. client._load_data() client.configure() client._allocate_data() + +def _disable_payload_processors(client): + """Keep the test focused on decoded client-server model payload contents.""" + client.inbound_processor = None + client.outbound_processor = None + client._context.inbound_processor = None + client._context.outbound_processor = None + + +def _assert_model_weight_payload(payload, model): + """Assert that an outbound payload contains exactly model state tensors.""" + model_state = model.state_dict() + + assert isinstance(payload, dict) + assert set(payload) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) + assert all(torch.is_tensor(value) for value in payload.values()) + + for name, expected in model_state.items(): + assert torch.equal(payload[name], expected) + + +def _assert_preserved_state_is_local(trainer, client_id): + """Assert optimizer and scheduler persistence exists only in trainer state.""" + state = trainer._preserved_optimizer_states[client_id] + + assert state["optimizer_state"]["state"] + assert state["scheduler_state"] is not None + assert state["scheduler_state"]["last_epoch"] >= 1 + assert state["scheduler_state"]["_step_count"] >= 2 + + +def test_simple_client_trains_with_default_strategies(temp_config): + """A simple client should complete one training round using the strategy stack.""" + Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) + + client = _build_client() + + _configure_one_round_client(client) + report, payload = asyncio.run(client._train()) assert report.client_id == 1 # With partition_size=4 each client receives four samples. assert report.num_samples == 4 - assert isinstance(payload, dict) - assert all(isinstance(value, torch.Tensor) for value in payload.values()) + _assert_model_weight_payload(payload, client.trainer.model) + + +def test_simple_client_payload_excludes_local_state_when_persistence_enabled( + temp_config, +): + """FedAvg/DiLoCo client payloads stay model-only with local persistence.""" + Config.params["run_id"] = "client-payload-in-process" + Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + _assert_preserved_state_is_local(client.trainer, client.client_id) + _assert_model_weight_payload(sent_payload, client.trainer.model) + + +def test_simple_client_subprocess_payload_excludes_local_state_sidecar( + temp_config, monkeypatch, tmp_path +): + """Subprocess persistence uses a sidecar without changing server payloads.""" + model_path = Path(tmp_path) / "models" / "pretrained" + checkpoint_path = Path(tmp_path) / "checkpoints" + model_path.mkdir(parents=True, exist_ok=True) + checkpoint_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(checkpoint_path) + Config.params["run_id"] = "client-payload-subprocess" + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + max_concurrency=1, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + state_path = ( + Path(Config.params["model_path"]) + / client.trainer._optimizer_state_filename(Config.params["run_id"]) + ) + with state_path.open("rb") as state_file: + sidecar_state = pickle.load(state_file) + + _assert_preserved_state_is_local(client.trainer, client.client_id) + assert sidecar_state["optimizer_state"]["state"] + assert sidecar_state["scheduler_state"] is not None + _assert_model_weight_payload(sent_payload, client.trainer.model) diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 6dbc1fa08..e499655be 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -24,13 +24,14 @@ class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" def __init__(self, train_size: int = 4, test_size: int = 2): + generator = torch.Generator().manual_seed(13) self._train = TensorDataset( - torch.randn(train_size, 1, 28, 28), - torch.randint(0, 10, (train_size,)), + torch.randn(train_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (train_size,), generator=generator), ) self._test = TensorDataset( - torch.randn(test_size, 1, 28, 28), - torch.randint(0, 10, (test_size,)), + torch.randn(test_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (test_size,), generator=generator), ) def num_train_examples(self): @@ -96,7 +97,6 @@ def test_fedavg_lenet5_smoke(monkeypatch): async_run(server._process_reports()) assert server.accuracy >= 0 - @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 3cb4a907b..4ff610756 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -107,6 +107,36 @@ def configure_environment(config_dict: dict): Config._instance = None +@contextlib.contextmanager +def configure_environment_from_path(config_path: Path): + """ + Context manager that initialises Config singleton from an existing config. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + Config._instance = None # reset singleton + Config.params = {} + + previous_env = os.environ.get("config_file") + previous_argv = sys.argv[:] + os.environ["config_file"] = str(config_path) + sys.argv = [ + previous_argv[0] if previous_argv else "pytest", + "--base", + tmp_dir, + ] + + try: + config = Config() + yield config + finally: + if previous_env is None: + os.environ.pop("config_file", None) + else: + os.environ["config_file"] = previous_env + sys.argv = previous_argv + Config._instance = None + + def async_run(coro): """Utility to execute the coroutine using asyncio.run (Python 3.7+).""" return asyncio.run(coro) diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py new file mode 100644 index 000000000..4c739051d --- /dev/null +++ b/tests/servers/test_diloco_strategy.py @@ -0,0 +1,822 @@ +"""Tests for DiLoCo server-side outer aggregation.""" + +import asyncio +import logging +from types import SimpleNamespace + +import pytest +import torch + +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy +from plato.servers.strategies.base import ServerContext + + +class DummyAlgorithm: + """Minimal algorithm stub for zero-delta construction.""" + + def __init__(self, baseline): + self.baseline = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.baseline.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + return [ + { + name: weights[name] - baseline_weights[name] + for name in baseline_weights.keys() + } + for weights in weights_list + ] + + +class ServerAlgorithm(DummyAlgorithm): + """Algorithm stub for exercising FedAvg-compatible server dispatch.""" + + def __init__(self, baseline): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + self.delta_payloads = None + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.current.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + self.delta_payloads = weights_list + return super().compute_weight_deltas(baseline_weights, weights_list) + + def update_weights(self, deltas): + self.current = { + name: self.current[name] + deltas[name] for name in self.current + } + return self.extract_weights() + + def load_weights(self, weights): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in weights.items() + } + + +class RecordingDiLoCoStrategy(DiLoCoAggregationStrategy): + """DiLoCo strategy recording server dispatch calls.""" + + def __init__(self): + super().__init__( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + self.delta_calls = 0 + self.last_updates = None + self.last_deltas = None + + async def aggregate_deltas(self, updates, deltas_received, context): + self.delta_calls += 1 + self.last_updates = updates + self.last_deltas = deltas_received + return await super().aggregate_deltas(updates, deltas_received, context) + + +class MixedStateModel(torch.nn.Module): + """Model exposing trainable, frozen, floating-buffer, and integer state.""" + + def __init__(self): + super().__init__() + self.trainable = torch.nn.Parameter(torch.tensor([1.0])) + self.frozen = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=False) + self.register_buffer("floating_buffer", torch.tensor([1.0])) + self.register_buffer("integer_buffer", torch.tensor([1], dtype=torch.int64)) + self.register_buffer("bool_buffer", torch.tensor([True], dtype=torch.bool)) + + +class PeftLikeAdapterModel(torch.nn.Module): + """Model whose adapter payload keys omit PEFT's default adapter segment.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.base_model = torch.nn.Module() + self.base_model.model = torch.nn.Module() + self.base_model.model.linear = torch.nn.Module() + self.base_model.model.linear.lora_A = torch.nn.ModuleDict( + {"default": torch.nn.Linear(1, 1, bias=False)} + ) + + +class AdapterAliasCollisionModel(torch.nn.Module): + """Model with a trainable parameter and separate payload key collision.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.foo = torch.nn.ModuleDict( + {"default": torch.nn.Linear(1, 1, bias=False)} + ) + + +def _context(baseline=None, model=None): + context = ServerContext() + if baseline is not None: + context.algorithm = DummyAlgorithm(baseline) + if model is not None: + context.trainer = SimpleNamespace(model=model) + return context + + +def _update(num_samples, report_type="weights"): + return SimpleNamespace( + report=SimpleNamespace(num_samples=num_samples, type=report_type) + ) + + +def _server_update(payload, num_samples=1, report_type="weights"): + update = _update(num_samples, report_type) + update.client_id = len(str(payload)) + update.report.accuracy = 0.5 + update.report.processing_time = 0.1 + update.report.comm_time = 0.1 + update.report.training_time = 0.1 + update.payload = payload + return update + + +def _aggregate(strategy, updates, deltas, baseline=None, model=None): + return asyncio.run( + strategy.aggregate_deltas(updates, deltas, _context(baseline, model)) + ) + + +def test_diloco_server_type_uses_fedavg_algorithm_and_strategy(temp_config): + """server.type=diloco should select a FedAvg-compatible DiLoCo server.""" + from plato.algorithms import registry as algorithms_registry + from plato.config import Config + from plato.servers import diloco as diloco_server + from plato.servers import fedavg + from plato.servers import registry as servers_registry + + Config().server.type = "diloco" + Config().algorithm.type = "fedavg" + Config().server.diloco = SimpleNamespace( + outer_optimizer="sgd", + outer_learning_rate=0.25, + outer_momentum=0.1, + aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + server = servers_registry.get() + + assert isinstance(server, diloco_server.Server) + assert isinstance(server, fedavg.Server) + assert isinstance(server.aggregation_strategy, DiLoCoAggregationStrategy) + assert server.aggregation_strategy.outer_optimizer == "sgd" + assert server.aggregation_strategy.outer_learning_rate == 0.25 + assert server.aggregation_strategy.outer_momentum == 0.1 + assert server.aggregation_strategy.aggregation_weighting == "num_samples" + assert server.aggregation_strategy.apply_outer_optimizer_to == "all_floating" + assert Config().algorithm.type == "fedavg" + assert "diloco" not in algorithms_registry.registered_algorithms + + +def test_diloco_server_process_reports_uses_delta_aggregation(temp_config): + """DiLoCo server processing should reach the delta aggregation path.""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [ + _server_update({"w": torch.tensor([2.0])}), + _server_update({"w": torch.tensor([4.0])}), + ] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert strategy.last_updates == server.updates + assert len(strategy.last_deltas) == 2 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([3.0])) + + +def test_diloco_server_does_not_use_inherited_weight_aggregation(temp_config): + """DiLoCo must not bypass delta aggregation via inherited FedAvg weights.""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + + async def fail_if_called(*_args, **_kwargs): + raise AssertionError("Inherited aggregate_weights() must not be called.") + + strategy.aggregate_weights = fail_if_called + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [_server_update({"w": torch.tensor([2.0])})] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + +def test_diloco_server_filters_non_weight_reports_before_delta_computation( + temp_config, +): + """Non-weight payloads should not reach compute_weight_deltas().""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + weight_payload = {"w": torch.tensor([2.0])} + server.updates = [ + _server_update("feature payload", report_type="features"), + _server_update({"metrics": 1.0}, report_type="metrics"), + _server_update(weight_payload), + ] + + asyncio.run(server._process_reports()) + + assert server.algorithm.delta_payloads == [weight_payload] + assert strategy.last_updates == [server.updates[2]] + assert len(strategy.last_deltas) == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + +def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): + """Outer SGD with lr=1 should match uniform averaging under uniform mode.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(99)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([15.0])) + + +def test_sgd_lr_one_num_samples_matches_weighted_fedavg(temp_config): + """Outer SGD with lr=1 should match sample-weighted FedAvg.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(3)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([6.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([16.5])) + + +def test_sgd_lr_half_moves_halfway_to_averaged_model(temp_config): + """A lower outer SGD lr should partially move toward the averaged model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(5), _update(5)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([2.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([12.5])) + + +def test_sgd_uses_diloco_outer_gradient_sign(temp_config): + """The strategy should negate Plato deltas before applying outer SGD.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.25, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([1.0])) + + +def test_outer_optimizer_application_is_logged(temp_config, caplog): + """A DiLoCo aggregation should report the server-side outer optimizer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=0.7, + outer_momentum=0.9, + aggregation_weighting="uniform", + apply_outer_optimizer_to="parameters", + ) + model = torch.nn.Linear(1, 1, bias=False) + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + with caplog.at_level(logging.INFO): + _aggregate( + strategy, + [_update(1), _update(1)], + [ + {"weight": torch.tensor([[2.0]])}, + {"weight": torch.tensor([[4.0]])}, + ], + baseline, + model, + ) + + message = caplog.text + assert "DiLoCo outer optimizer applied" in message + assert "optimizer=nesterov" in message + assert "outer_lr=0.7" in message + assert "outer_momentum=0.9" in message + assert "weighting=uniform" in message + assert "apply_to=parameters" in message + assert "eligible_updates=2" in message + assert "optimized_tensors=1" in message + + +def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): + """Uniform mode should weight eligible clients equally.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + server_delta = _aggregate( + strategy, + [_update(1), _update(1000)], + [{"w": torch.tensor([0.0])}, {"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([5.0])) + + +def test_nonpositive_sample_reports_are_ineligible(temp_config): + """Reports with zero or negative sample counts should not affect averages.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + server_delta = _aggregate( + strategy, + [_update(0), _update(-5), _update(10)], + [ + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([4.0])}, + ], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0])) + + +def test_empty_eligible_updates_return_zero_delta(temp_config): + """An empty eligible set should produce a zero delta matching the baseline.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + baseline = {"w": torch.tensor([3.0, 4.0])} + server_delta = _aggregate( + strategy, + [_update(0), _update(5, report_type="features")], + [{"w": torch.tensor([10.0, 10.0])}, {"w": torch.tensor([10.0, 10.0])}], + baseline, + ) + + assert torch.allclose(server_delta["w"], torch.zeros_like(baseline["w"])) + + +def test_empty_eligible_updates_remove_stale_momentum(temp_config): + """A round with no eligible keys should clear stale momentum buffers.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + server_delta = _aggregate( + strategy, + [_update(0)], + [{"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([0.0])) + assert strategy.momentum_state == {} + + +def test_sgdm_persists_momentum_across_rounds(temp_config): + """Momentum SGD should reuse server-side outer momentum across rounds.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([2.0])) + assert torch.allclose(second_delta["w"], torch.tensor([5.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_nesterov_uses_pytorch_style_two_round_recurrence(temp_config): + """Nesterov should use g + beta * m after updating the momentum buffer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([3.0])) + assert torch.allclose(second_delta["w"], torch.tensor([6.5])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_momentum_state_resets_on_shape_mismatch_and_removes_stale_keys( + temp_config, +): + """Momentum state should reset incompatible keys and prune missing keys.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0]), "b": torch.tensor([1.0])}], + {"w": torch.tensor([0.0]), "b": torch.tensor([0.0])}, + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0, 6.0])}], + {"w": torch.tensor([0.0, 0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0, 6.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-4.0, -6.0])) + assert "b" not in strategy.momentum_state + + +def test_momentum_state_resets_on_dtype_mismatch(temp_config): + """Momentum state should reset when the tensor dtype changes.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0], dtype=torch.float32)}], + {"w": torch.tensor([0.0], dtype=torch.float32)}, + ) + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0], dtype=torch.float64)}], + {"w": torch.tensor([0.0], dtype=torch.float64)}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0], dtype=torch.float64)) + assert strategy.momentum_state["w"].dtype == torch.float64 + + +def test_parameters_policy_optimizes_only_trainable_floating_parameters( + temp_config, +): + """Default policy should leave frozen parameters and buffers on FedAvg deltas.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + first_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + model, + ) + + assert torch.allclose(first_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(first_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(first_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(first_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(first_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-4.0])) + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + model, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([6.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([6.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) + assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-8.0])) + + +def test_all_floating_policy_optimizes_every_floating_state_tensor(temp_config): + """All-floating mode should not require model context for eligibility.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + server_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + ) + + assert torch.allclose(server_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(server_delta["frozen"], torch.tensor([2.0])) + assert torch.allclose(server_delta["floating_buffer"], torch.tensor([2.0])) + assert torch.equal(server_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(server_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + +def test_parameters_policy_requires_trainer_model_context(temp_config): + """Default parameter eligibility should fail clearly without a model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + with pytest.raises(AttributeError, match="context.trainer.model"): + _aggregate( + strategy, + [_update(1)], + [{"trainable": torch.tensor([2.0])}], + {"trainable": torch.tensor([0.0])}, + ) + + +def test_parameters_policy_maps_peft_adapter_payload_names(temp_config): + """PEFT payloads can omit adapter-name segments from trainable param names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = PeftLikeAdapterModel() + payload_name = "base_model.model.linear.lora_A.weight" + baseline = {payload_name: torch.zeros((1, 1))} + + server_delta = _aggregate( + strategy, + [_update(1)], + [{payload_name: torch.full((1, 1), 4.0)}], + baseline, + model, + ) + + assert torch.allclose(server_delta[payload_name], torch.full((1, 1), 2.0)) + assert set(strategy.momentum_state) == {payload_name} + assert torch.allclose( + strategy.momentum_state[payload_name], torch.full((1, 1), -4.0) + ) + + +def test_parameters_policy_does_not_overmatch_adapter_alias_collisions(temp_config): + """Alias support should not optimize unrelated colliding payload names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = AdapterAliasCollisionModel() + trainable_name = "foo.default.weight" + colliding_name = "foo.weight" + baseline = { + trainable_name: torch.zeros((1, 1)), + colliding_name: torch.zeros((1, 1)), + } + + server_delta = _aggregate( + strategy, + [_update(1)], + [ + { + trainable_name: torch.full((1, 1), 4.0), + colliding_name: torch.full((1, 1), 4.0), + } + ], + baseline, + model, + ) + + assert torch.allclose(server_delta[trainable_name], torch.full((1, 1), 2.0)) + assert torch.allclose(server_delta[colliding_name], torch.full((1, 1), 4.0)) + assert set(strategy.momentum_state) == {trainable_name} + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"outer_optimizer": "adam"}, "outer_optimizer"), + ({"aggregation_weighting": "weighted"}, "aggregation_weighting"), + ({"apply_outer_optimizer_to": "buffers"}, "apply_outer_optimizer_to"), + ({"outer_learning_rate": -0.1}, "outer_learning_rate"), + ({"outer_momentum": -0.1}, "outer_momentum"), + ({"outer_momentum": 1.0}, "outer_momentum"), + ], +) +def test_invalid_config_values_fail_clearly(temp_config, kwargs, match): + """Invalid DiLoCo aggregation configuration should raise clear errors.""" + with pytest.raises(ValueError, match=match): + DiLoCoAggregationStrategy(**kwargs) diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py new file mode 100644 index 000000000..b005981f0 --- /dev/null +++ b/tests/trainers/test_composable_optimizer_state.py @@ -0,0 +1,644 @@ +"""Tests for in-process optimizer state preservation in ComposableTrainer.""" + +import copy +import os +import pickle +import sys +from collections import OrderedDict +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset + +from plato.config import Config +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import ( + AdamWOptimizerStrategy, + CrossEntropyLossStrategy, + DefaultTrainingStepStrategy, + SGDOptimizerStrategy, + StepLRSchedulerStrategy, +) +from plato.trainers.strategies.base import OptimizerStrategy, TrainingContext + +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} + + +@pytest.fixture +def tiny_dataset(): + features = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [-1.0, 0.5], + ], + dtype=torch.float32, + ) + labels = torch.tensor([0, 1, 0, 1], dtype=torch.long) + return TensorDataset(features, labels) + + +@pytest.fixture +def one_step_config(): + return { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "optimizer-state-test", + } + + +class CapturingTrainingStep(DefaultTrainingStepStrategy): + """Record optimizer state before each local optimizer step.""" + + def __init__(self): + super().__init__() + self.pre_step_states = [] + self.pre_step_lrs = [] + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + optimizer_state = optimizer.state_dict() + self.pre_step_states.append(copy.deepcopy(optimizer_state["state"])) + self.pre_step_lrs.append( + [group["lr"] for group in optimizer_state["param_groups"]] + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + +def _linear_model(): + return nn.Sequential(OrderedDict([("linear", nn.Linear(2, 2))])) + + +class DeviceTrackingModel(nn.Module): + """Model that records whether it has been moved to a trainer device.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 2) + self.moved_to_trainer_device = False + + def forward(self, features): + return self.linear(features) + + def to(self, *args, **kwargs): + self.moved_to_trainer_device = True + return super().to(*args, **kwargs) + + +class RestoreOrderOptimizer(torch.optim.SGD): + """Optimizer that records whether state restore happens after model.to().""" + + def __init__(self, params, model: DeviceTrackingModel): + self.model = model + self.loaded_after_model_to = None + super().__init__(params, lr=0.01, momentum=0.9) + + def load_state_dict(self, state_dict): + self.loaded_after_model_to = self.model.moved_to_trainer_device + if not self.loaded_after_model_to: + raise AssertionError("optimizer state restored before model.to()") + return super().load_state_dict(state_dict) + + +class RestoreOrderOptimizerStrategy(OptimizerStrategy): + """Create restore-order-aware optimizers for regression tests.""" + + def __init__(self): + self.optimizers = [] + + def create_optimizer( + self, model: DeviceTrackingModel, context: TrainingContext + ) -> torch.optim.Optimizer: + optimizer = RestoreOrderOptimizer(model.parameters(), model) + self.optimizers.append(optimizer) + return optimizer + + +def _two_layer_model(first_name="first", second_name="second"): + return nn.Sequential( + OrderedDict( + [ + (first_name, nn.Linear(2, 2, bias=False)), + (second_name, nn.Linear(2, 2, bias=False)), + ] + ) + ) + + +def _first_param_state(optimizer_state): + return next(iter(optimizer_state.values())) + + +def _state_step(param_state): + step = param_state["step"] + if isinstance(step, torch.Tensor): + return int(step.item()) + return int(step) + + +def _configure_subprocess_training( + monkeypatch, + tmp_path, + *, + preserve_optimizer_state, +): + """Configure parent and spawned child processes to share local artifacts.""" + model_path = Path(tmp_path) / "models" / "pretrained" + model_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(Path(tmp_path) / "checkpoints") + Config.params["run_id"] = "subprocess-optimizer-state" + os.makedirs(Config.params["checkpoint_path"], exist_ok=True) + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + Config().trainer = Config().trainer._replace( + max_concurrency=1, + preserve_optimizer_state=preserve_optimizer_state, + batch_size=4, + epochs=1, + ) + + +def _cached_optimizer_step(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return _state_step(_first_param_state(payload["optimizer_state"]["state"])) + + +def _cached_scheduler_last_epoch(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return payload["scheduler_state"]["last_epoch"] + + +def _assert_model_update_contains_only_model_weights(update, model): + model_state = model.state_dict() + + assert set(update) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(update) + assert all(torch.is_tensor(value) for value in update.values()) + + +def test_adamw_moment_buffers_persist_between_rounds_for_same_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + trainer.set_client_id(7) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + round1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[1]) + saved_state = _first_param_state(round1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + assert torch.allclose(restored_state["exp_avg_sq"], saved_state["exp_avg_sq"]) + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 2 + + +def test_preserved_optimizer_state_restores_after_model_moves_to_device( + temp_config, tiny_dataset, one_step_config +): + config = {**one_step_config, "preserve_optimizer_state": True} + source_trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=RestoreOrderOptimizerStrategy(), + ) + source_trainer.set_client_id(11) + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + restore_strategy = RestoreOrderOptimizerStrategy() + trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=restore_strategy, + ) + trainer.set_client_id(11) + trainer._preserved_optimizer_states[11] = copy.deepcopy( + source_trainer._preserved_optimizer_states[11] + ) + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert restore_strategy.optimizers[0].loaded_after_model_to is True + restored_state = _first_param_state( + trainer._preserved_optimizer_states[11]["optimizer_state"]["state"] + ) + assert "momentum_buffer" in restored_state + + +def test_scheduler_state_and_lr_progress_persist_between_rounds( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + training_step_strategy=step_strategy, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_lrs == [[0.2], [0.1]] + assert trainer.lr_scheduler.last_epoch == 2 + assert trainer.optimizer.param_groups[0]["lr"] == pytest.approx(0.05) + + +def test_subprocess_optimizer_state_parent_reloads_after_child( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer.client_id in trainer._preserved_optimizer_states + assert _cached_optimizer_step(trainer) == 1 + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + assert state_path.exists() + assert "optimizer_state" not in trainer.obtain_model_update( + { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "payload-check", + "preserve_optimizer_state": True, + }, + tiny_dataset, + list(range(len(tiny_dataset))), + ) + + +def test_subprocess_optimizer_state_persists_across_rounds_for_same_client( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 2 + + +def test_subprocess_scheduler_state_persists_across_rounds( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert _cached_scheduler_last_epoch(trainer) == 2 + assert payload["optimizer_state"]["param_groups"][0]["lr"] == pytest.approx(0.05) + + +def test_subprocess_missing_sidecar_clears_inherited_parent_cache( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + source_trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + source_trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + assert _cached_optimizer_step(source_trainer) == 1 + + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + trainer._preserved_optimizer_states[7] = copy.deepcopy( + source_trainer._preserved_optimizer_states[7] + ) + + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + state_path.unlink(missing_ok=True) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 1 + + +def test_missing_subprocess_output_removes_stale_input_sidecar( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + input_filename = trainer._optimizer_state_filename(Config.params["run_id"]) + missing_output_filename = trainer._optimizer_state_output_filename( + Config.params["run_id"] + ) + assert trainer._save_preserved_optimizer_state_file(input_filename) + input_path = Path(Config.params["model_path"]) / input_filename + assert input_path.exists() + + trainer._finish_subprocess_optimizer_state( + input_filename, missing_output_filename + ) + + assert trainer.client_id not in trainer._preserved_optimizer_states + assert not input_path.exists() + + +def test_subprocess_invalid_optimizer_state_resets_safely( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + with open(state_path, "wb") as state_file: + pickle.dump({"optimizer_type": torch.optim.SGD}, state_file) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert payload["optimizer_type"] is torch.optim.AdamW + assert _cached_optimizer_step(trainer) == 1 + + +def test_subprocess_optimizer_state_is_not_persisted_when_disabled( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=False + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer._preserved_optimizer_states == {} + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + assert not state_path.exists() + + +def test_preserved_optimizer_state_is_local_to_logical_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + client1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + + trainer.set_client_id(2) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + assert step_strategy.pre_step_states[1] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[2]) + saved_state = _first_param_state(client1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + + +def test_preserved_state_stays_out_of_model_update_payload( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(5) + config = {**one_step_config, "preserve_optimizer_state": True} + + update = trainer.obtain_model_update( + config, tiny_dataset, list(range(len(tiny_dataset))) + ) + preserved_state = trainer._preserved_optimizer_states[trainer.client_id] + + assert preserved_state["optimizer_state"]["state"] + assert preserved_state["scheduler_state"]["last_epoch"] >= 1 + _assert_model_update_contains_only_model_weights(update, trainer.model) + + +def test_preserved_state_invalidates_when_parameter_order_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=lambda: _two_layer_model("first", "second"), + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.model = _two_layer_model("second", "first") + trainer.context.model = trainer.model + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + + +def test_preserved_state_invalidates_when_optimizer_type_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.optimizer_strategy = SGDOptimizerStrategy(lr=0.1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + assert isinstance(trainer.optimizer, torch.optim.SGD) + + +def test_preserved_state_compatibility_rejects_shape_dtype_and_scheduler_changes( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(4) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + payload = copy.deepcopy(trainer._preserved_optimizer_states[4]) + + current_model = trainer.model + current_optimizer = trainer.optimizer_strategy.create_optimizer( + current_model, trainer.context + ) + changed_scheduler = StepLRSchedulerStrategy( + step_size=1, gamma=0.5 + ).create_scheduler(current_optimizer, trainer.context) + assert not trainer._preserved_state_is_compatible( + payload, current_model, current_optimizer, changed_scheduler + ) + + changed_shape_model = nn.Sequential( + OrderedDict([("linear", nn.Linear(2, 3))]) + ) + changed_shape_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_shape_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_shape_model, changed_shape_optimizer, None + ) + + changed_dtype_model = _linear_model().to(torch.float64) + changed_dtype_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_dtype_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_dtype_model, changed_dtype_optimizer, None + ) + + +@pytest.mark.parametrize("preserve_value", [None, False]) +def test_optimizer_state_is_not_restored_when_disabled_or_unset( + temp_config, tiny_dataset, one_step_config, preserve_value +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = dict(one_step_config) + if preserve_value is not None: + config["preserve_optimizer_state"] = preserve_value + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states == [{}, {}] + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 1 diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index ff35f76ec..fd6b4a668 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -5,11 +5,14 @@ it works correctly in end-to-end training scenarios. """ +import logging + import pytest import torch import torch.nn as nn -from torch.utils.data import TensorDataset +from torch.utils.data import SubsetRandomSampler, TensorDataset +from plato.callbacks.trainer import TrainerCallback from plato.config import Config from plato.evaluators.runner import EVALUATION_PRIMARY_KEY, EVALUATION_RESULTS_KEY from plato.trainers.composable import ComposableTrainer @@ -18,13 +21,16 @@ CrossEntropyLossStrategy, DefaultDataLoaderStrategy, DefaultTrainingStepStrategy, + GradientAccumulationStepStrategy, NoOpUpdateStrategy, NoSchedulerStrategy, + StepLRSchedulerStrategy, ) from plato.trainers.strategies.base import ( LossCriterionStrategy, ModelUpdateStrategy, TrainingContext, + TrainingStepStrategy, ) @@ -178,6 +184,452 @@ def test_multiple_epochs(self, simple_model, simple_dataset): assert len(trainer.run_history.get_metric_values("train_loss")) == 5 +class TestComposableTrainerLocalSteps: + """Test local optimizer-step limits for DiLoCo-style local work.""" + + class DeterministicPlatoSampler: + def __init__(self, indices, seed=47): + self.indices = list(indices) + self.seed = seed + + def get(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + return SubsetRandomSampler(self.indices, generator=generator) + + def num_samples(self): + return len(self.indices) + + class NonMaterializableSampler(torch.utils.data.Sampler): + def __iter__(self): + raise NotImplementedError("This sampler cannot be materialized.") + + def __len__(self): + return 10 + + class CountingCallback(TrainerCallback): + def __init__(self): + self.train_run_end_called = False + self.train_step_end_count = 0 + + def on_train_step_end(self, trainer, config, batch, loss, **kwargs): + self.train_step_end_count += 1 + + def on_train_run_end(self, trainer, config, **kwargs): + self.train_run_end_called = True + + class CountingUpdateStrategy(ModelUpdateStrategy): + def __init__(self): + self.after_step_count = 0 + self.on_train_end_called = False + + def after_step(self, context): + self.after_step_count += 1 + + def on_train_end(self, context): + self.on_train_end_called = True + + class CountingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.batch_count += 1 + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + class RecordingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.samples_by_round = {} + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + sample_ids = examples[:, 0].detach().cpu().int().tolist() + self.samples_by_round.setdefault(context.current_round, []).extend( + sample_ids + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + class DelayedOptimizerStepStrategy(TrainingStepStrategy): + def __init__(self, accumulation_steps=2, finalize_steps=True): + self.accumulation_steps = accumulation_steps + self.finalize_steps = finalize_steps + self.raw_batch_count = 0 + self.optimizer_step_count = 0 + self.finalize_calls = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + outputs = model(examples) + loss = loss_criterion(outputs, labels) + (loss / self.accumulation_steps).backward() + + self.raw_batch_count += 1 + if self.raw_batch_count % self.accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False + + return loss + + def finalize(self, model, optimizer, context): + self.finalize_calls += 1 + if not self.finalize_steps: + return None + + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + return torch.tensor(0.0) + + class CountingGradientAccumulationStepStrategy(GradientAccumulationStepStrategy): + def __init__(self, accumulation_steps): + super().__init__(accumulation_steps=accumulation_steps) + self.raw_batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.raw_batch_count += 1 + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + def test_local_steps_stop_mid_epoch_and_run_cleanup( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + callback = self.CountingCallback() + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + callbacks=[callback], + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 3 + assert update_strategy.after_step_count == 3 + assert callback.train_step_end_count == 3 + assert trainer.current_epoch == 1 + assert trainer.context.state["local_optimizer_steps"] == 3 + assert update_strategy.on_train_end_called + assert callback.train_run_end_called + + def test_local_steps_count_optimizer_steps_not_raw_batches( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.DelayedOptimizerStepStrategy(accumulation_steps=3) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert step_strategy.optimizer_step_count == 2 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + + def test_local_steps_respect_builtin_gradient_accumulation( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingGradientAccumulationStepStrategy( + accumulation_steps=3 + ) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + + def test_local_steps_skip_finalize_after_limit_is_reached( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 2, + "local_steps_per_round": 1, + } + step_strategy = self.DelayedOptimizerStepStrategy(accumulation_steps=2) + update_strategy = self.CountingUpdateStrategy() + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 2 + assert step_strategy.optimizer_step_count == 1 + assert step_strategy.finalize_calls == 0 + assert update_strategy.after_step_count == 1 + assert trainer.context.state["local_optimizer_steps"] == 1 + + def test_epoch_behavior_is_unchanged_when_local_steps_unset( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 10, + "epochs": 2, + } + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 20 + assert trainer.current_epoch == 2 + assert len(trainer.run_history.get_metric_values("train_loss")) == 2 + + def test_local_steps_do_not_replay_same_deterministic_sampler_prefix( + self, simple_model, simple_config + ): + dataset_size = 10 + features = torch.arange(dataset_size, dtype=torch.float32).view(-1, 1) + features = features.repeat(1, 10) + labels = torch.arange(dataset_size) % 2 + dataset = TensorDataset(features, labels) + config = { + **simple_config, + "batch_size": 1, + "epochs": 1, + "local_steps_per_round": 3, + } + sampler = self.DeterministicPlatoSampler(range(dataset_size)) + step_strategy = self.RecordingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + trainer.set_client_id(2) + + for round_number in (1, 2): + trainer.current_round = round_number + trainer.train_model(config, dataset, sampler) + + round_one_samples = step_strategy.samples_by_round[1] + round_two_samples = step_strategy.samples_by_round[2] + + assert len(round_one_samples) == config["local_steps_per_round"] + assert len(round_two_samples) == config["local_steps_per_round"] + assert round_one_samples != round_two_samples + + repeat_step_strategy = self.RecordingStepStrategy() + repeat_trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=repeat_step_strategy, + ) + repeat_trainer.set_client_id(2) + + for round_number in (1, 2): + repeat_trainer.current_round = round_number + repeat_trainer.train_model(config, dataset, sampler) + + assert repeat_step_strategy.samples_by_round == step_strategy.samples_by_round + + def test_local_step_sampling_warns_for_non_materializable_sampler( + self, simple_dataset, caplog + ): + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + sampler = self.NonMaterializableSampler() + + with caplog.at_level(logging.WARNING): + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + sampler, + batch_size=1, + context=context, + ) + + assert loader.sampler is sampler + assert ( + "cannot be materialized for round-aware local-step sampling" + in caplog.text + ) + + def test_diloco_local_steps_require_full_client_participation( + self, simple_dataset, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + with pytest.raises( + ValueError, match="clients\\.per_round.*clients\\.total_clients" + ): + DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + def test_partial_participation_still_allowed_without_diloco_local_steps( + self, simple_dataset, temp_config + ): + Config().server.type = "fedavg" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + assert len(loader.sampler) == len(simple_dataset) + + def test_diloco_local_steps_advance_lr_scheduler_per_optimizer_step( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 4 + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 3 + + def test_non_diloco_local_steps_keep_epoch_based_lr_scheduler( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "fedavg" + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 1 + + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) + def test_invalid_local_steps_fail_clearly( + self, simple_model, simple_dataset, simple_config, local_steps_per_round + ): + config = { + **simple_config, + "local_steps_per_round": local_steps_per_round, + } + trainer = ComposableTrainer(model=simple_model) + + with pytest.raises(ValueError, match="local_steps_per_round"): + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + class TestComposableTrainerStrategies: """Test strategy integration."""