"""
MDP Defense Example

Demonstrates the Matrix Decomposition + Differential Privacy defense
for protecting GNN models against privacy attacks.

Usage:
    python examples/defense/MDP.py
"""

from pygip.datasets import Cora, CiteSeer, PubMed
from pygip.models.defense import MDP


def main():
    """Run the MDP defense once on Cora and report metrics."""
    # Load dataset with PyG format (required for MDP)
    print("Loading Cora dataset...")
    dataset = Cora(api_type='pyg')

    print(f"Dataset: {dataset.dataset_name}")
    print(f"Nodes: {dataset.num_nodes}")
    print(f"Features: {dataset.num_features}")
    print(f"Classes: {dataset.num_classes}")

    # Initialize MDP defense
    print("\nInitializing MDP defense...")
    defense = MDP(
        dataset=dataset,
        attack_node_fraction=0.1,
        # MDP parameters
        nc=4,              # Number of federated calculators
        es=2,              # Number of eigenvalue shares
        epsilon=30.0,      # DP privacy budget (higher = less noise)
        keep_ratio=0.8,    # Training data fraction per calculator
        # Model architecture
        hidden_dim=16,
        dropout=0.5,
        # Training
        lr=0.01,
        weight_decay=5e-4,
        epochs=200,
        patience=50,
        seed=42,
    )

    # Execute defense.
    # NOTE: defend() returns a (metrics, computational_metrics) tuple, not a
    # single dict — unpack both.
    print("\nExecuting MDP defense...")
    res, res_comp = defense.defend()

    # Print results
    print("\n" + "=" * 50)
    print("DEFENSE RESULTS")
    print("=" * 50)

    for key, value in res.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")

    print("\nCOMPUTATIONAL METRICS")
    for key, value in res_comp.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")

    # Access internal state if needed
    print("\nDefense Details:")
    print(f"  Adjacency shares: {len(defense.get_adjacency_shares())}")
    stats = defense.get_training_stats()
    if stats:
        print(f"  Training epochs: {stats['epochs_trained']}")
        print(f"  Final val accuracy: {stats['val_acc'][-1]:.4f}")


def run_epsilon_sweep():
    """Sweep over different privacy budgets and report test accuracy."""
    print("\n" + "=" * 50)
    print("EPSILON SWEEP")
    print("=" * 50)

    dataset = Cora(api_type='pyg')

    epsilons = [10.0, 20.0, 30.0, float('inf')]

    for eps in epsilons:
        defense = MDP(
            dataset=dataset,
            nc=4,
            es=2,
            epsilon=eps,
            epochs=100,
            patience=30,
            seed=42,
        )

        # defend() returns (metrics, computational_metrics); only the first
        # dict carries accuracy-style results.
        res, _ = defense.defend()
        acc = res.get('test_acc', 0)

        eps_str = "inf" if eps == float('inf') else f"{eps:.1f}"
        print(f"epsilon={eps_str:>6}: test_acc={acc:.4f}")


if __name__ == "__main__":
    main()
    run_epsilon_sweep()
class MDP(BaseDefense):
    """
    Matrix Decomposition + Differential Privacy defense for GNNs.

    Pipeline:
      1. Build the normalized adjacency matrix Abar = I + D^(-1/2) A D^(-1/2).
      2. Add calibrated Laplace noise to node features (epsilon-DP).
      3. Split Abar into ``es`` eigenvalue shares, cycled across ``nc``
         federated "calculators".
      4. Train one local GCN per calculator on its share; average parameters
         across all calculators after every epoch.

    Attributes:
        supported_api_types: Compatible API types (pyg only — the defense
            operates on dense adjacency matrices).
        supported_datasets: Compatible dataset names (empty = all supported).
    """

    supported_api_types = {"pyg"}
    supported_datasets = set()

    def __init__(
        self,
        dataset,
        attack_node_fraction: float = 0.1,
        device: Optional[Union[str, torch.device]] = None,
        nc: int = 4,
        es: int = 2,
        epsilon: float = 30.0,
        keep_ratio: float = 1.0,
        hidden_dim: int = 16,
        dropout: float = 0.5,
        lr: float = 0.01,
        weight_decay: float = 5e-4,
        epochs: int = 200,
        patience: int = 50,
        seed: int = 42,
    ):
        """
        Initialize MDP defense.

        Args:
            dataset: PyGIP Dataset instance (must have api_type='pyg')
            attack_node_fraction: Fraction of nodes considered under attack
            device: Torch device (auto-detected if None)
            nc: Number of calculators for federated training (>= 2)
            es: Number of eigenvalue shares (>= 2)
            epsilon: Differential privacy budget (float('inf') for no noise)
            keep_ratio: Fraction of training nodes each calculator sees (0,1]
            hidden_dim: Hidden layer dimension for GCN
            dropout: Dropout probability
            lr: Learning rate
            weight_decay: L2 regularization weight
            epochs: Maximum training epochs
            patience: Early stopping patience (epochs without val improvement)
            seed: Random seed for reproducibility

        Raises:
            ValueError: If nc < 2, es < 2, keep_ratio outside (0, 1], or
                epsilon is neither positive nor infinite.
        """
        super().__init__(dataset, attack_node_fraction, device)

        # Fail fast on invalid hyper-parameters before any heavy setup.
        if nc < 2:
            raise ValueError("nc (number of calculators) must be >= 2")
        if es < 2:
            raise ValueError("es (number of eigenvalue shares) must be >= 2")
        if not (0.0 < keep_ratio <= 1.0):
            raise ValueError("keep_ratio must be in (0, 1]")
        if epsilon <= 0 and epsilon != float('inf'):
            raise ValueError("epsilon must be > 0 or float('inf')")

        self.nc = nc
        self.es = es
        self.epsilon = epsilon
        self.keep_ratio = keep_ratio
        self.seed = seed

        self.hidden_dim = hidden_dim
        self.dropout = dropout

        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.patience = patience

        # Populated lazily by defend(); kept on the instance so callers can
        # inspect intermediate state after the run.
        self.defense_model: Optional[nn.Module] = None
        self.Abar: Optional[torch.Tensor] = None
        self.Abar_shares: Optional[List[torch.Tensor]] = None
        self.X_noised: Optional[torch.Tensor] = None
        self._training_stats: Optional[Dict] = None

    def defend(self):
        """
        Execute the MDP defense end-to-end.

        Returns:
            Tuple of (res, res_comp) where:
              - res: Dictionary from DefenseMetric.compute()
              - res_comp: Dictionary from DefenseCompMetric.compute()
        """
        metric_comp = DefenseCompMetric()
        metric_comp.start()
        print("====================MDP Defense====================")

        # Build adjacency, privatize features, then split the matrix.
        self._build_adjacency()
        self._apply_dp_noise()
        self._split_adjacency()

        # Train defense model (timed separately from inference).
        defense_s = time.time()
        self.defense_model = self._train_defense_model()
        defense_e = time.time()
        metric_comp.update(defense_time=(defense_e - defense_s))

        # Evaluate on the test split.
        inference_s = time.time()
        preds, labels = self._get_predictions()
        inference_e = time.time()

        # Compute metrics
        metric = DefenseMetric()
        metric.update(preds, labels)
        metric_comp.end()

        print("====================Final Results====================")
        res = metric.compute()
        metric_comp.update(inference_defense_time=(inference_e - inference_s))
        res_comp = metric_comp.compute()

        return res, res_comp

    def _build_adjacency(self) -> None:
        """Build normalized adjacency matrix Abar = I + D^(-1/2)AD^(-1/2)."""
        data = self.graph_data
        edge_index = data.edge_index.to(self.device)
        num_nodes = self.num_nodes

        abar_result = build_abar_dense(
            edge_index=edge_index,
            num_nodes=num_nodes,
            device=self.device
        )
        self.Abar = abar_result.Abar.to(torch.float32)

    def _apply_dp_noise(self) -> None:
        """Apply Laplace noise to features for differential privacy."""
        data = self.graph_data
        X = data.x.to(self.device).to(torch.float32)

        dp_result = dp_features_laplace(
            X,
            epsilon=self.epsilon,
            delta=1.0,
            clip_min=0.0,
            clip_max=1.0,
            seed=self.seed
        )
        self.X_noised = dp_result.X_dp

    def _split_adjacency(self) -> None:
        """Split adjacency matrix into es shares, cycled over nc calculators."""
        es_result = es_split_into_nc(
            self.Abar,
            nc=self.es,
            seed=self.seed,
            assume_symmetric=True
        )

        # When nc > es, shares are reused round-robin so every calculator
        # still gets an adjacency share.
        self.Abar_shares = [
            es_result.shares[i % self.es].to(torch.float32)
            for i in range(self.nc)
        ]

    def _train_target_model(self):
        """
        Train the target model (baseline without defense).

        Uses the same early-stopping policy (``self.patience``) as the
        federated defense training for a fair comparison.

        Returns:
            torch.nn.Module: The trained target model (best-val checkpoint)
        """
        print("Training target model...")

        in_dim = self.num_features
        out_dim = self.num_classes

        model = ManualGCN(
            in_dim=in_dim,
            hidden_dim=self.hidden_dim,
            out_dim=out_dim,
            dropout=self.dropout
        ).to(self.device)

        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )

        data = self.graph_data
        X = data.x.to(self.device).to(torch.float32)
        y = data.y.to(self.device)
        train_mask = data.train_mask.to(self.device)
        val_mask = data.val_mask.to(self.device)

        best_val = -1.0
        best_state = None
        epochs_since_improve = 0

        for epoch in range(1, self.epochs + 1):
            model.train()
            logits = model(self.Abar, X)
            loss = F.cross_entropy(logits[train_mask], y[train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                logits = model(self.Abar, X)
                val_acc = self._accuracy(logits, y, val_mask)

            if val_acc > best_val:
                best_val = val_acc
                # Checkpoint on CPU so GPU memory is not held by snapshots.
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
                epochs_since_improve = 0
            else:
                epochs_since_improve += 1

            if epochs_since_improve >= self.patience:
                break

        if best_state:
            model.load_state_dict(best_state, strict=True)

        print(f"Target model trained. Val accuracy: {best_val:.4f}")
        return model

    def _train_defense_model(self):
        """
        Train the defense model using federated averaging across calculators.

        Returns:
            torch.nn.Module: The trained defense model (best-val checkpoint)
        """
        print("Training defense model with MDP...")

        data = self.graph_data

        # Only subsample per-calculator training masks when shares are
        # reused (nc > es); otherwise every calculator trains on the full
        # training set.
        train_masks_per_calc = None
        if self.nc > self.es:
            train_masks_per_calc = make_overlapping_train_masks(
                data.train_mask.to(self.device),
                nc=self.nc,
                seed=self.seed,
                keep_ratio=self.keep_ratio
            )

        in_dim = self.num_features
        out_dim = self.num_classes

        def model_ctor():
            return ManualGCN(
                in_dim=in_dim,
                hidden_dim=self.hidden_dim,
                out_dim=out_dim,
                dropout=self.dropout
            )

        stats, best_state = self._federated_train(
            model_ctor=model_ctor,
            Abar_full=self.Abar,
            Abar_shares=self.Abar_shares,
            X=self.X_noised,
            y=data.y.to(self.device),
            train_mask=data.train_mask.to(self.device),
            val_mask=data.val_mask.to(self.device),
            train_masks_per_calc=train_masks_per_calc,
        )

        self._training_stats = stats

        model = model_ctor().to(self.device)
        model.load_state_dict(best_state, strict=True)

        # Report the best validation accuracy — that is the checkpoint
        # actually loaded above; val_acc[-1] would be the (possibly worse)
        # accuracy of the last epoch before early stopping.
        print(f"Defense model trained. Best val accuracy: {max(stats['val_acc']):.4f}")
        return model

    def _train_surrogate_model(self):
        """
        Train surrogate model (for attack evaluation).

        Returns:
            torch.nn.Module: The trained surrogate model
        """
        # The surrogate shares the target's architecture and training recipe.
        return self._train_target_model()

    def _federated_train(
        self,
        model_ctor,
        Abar_full: torch.Tensor,
        Abar_shares: List[torch.Tensor],
        X: torch.Tensor,
        y: torch.Tensor,
        train_mask: torch.Tensor,
        val_mask: torch.Tensor,
        train_masks_per_calc: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[Dict, Dict[str, torch.Tensor]]:
        """
        Federated training loop with parameter averaging.

        Each calculator trains on its assigned adjacency share; after every
        epoch the parameters are averaged across all calculators and pushed
        back to each local model. Validation runs on the FULL adjacency.

        Returns:
            Tuple of (stats dict, best state_dict on CPU).
        """
        nc = len(Abar_shares)

        local_models = [model_ctor().to(self.device) for _ in range(nc)]
        optimizers = [
            torch.optim.Adam(m.parameters(), lr=self.lr, weight_decay=self.weight_decay)
            for m in local_models
        ]

        train_loss_hist = []
        val_acc_hist = []

        best_val = -1.0
        best_state = None
        epochs_since_improve = 0

        for epoch in range(1, self.epochs + 1):
            local_states = []
            local_losses = []

            for i in range(nc):
                model = local_models[i]
                opt = optimizers[i]
                model.train()

                A_share = Abar_shares[i].to(self.device)
                logits = model(A_share, X)

                # Explicit None check: an empty list would otherwise fall
                # back silently via truthiness.
                tm = train_masks_per_calc[i] if train_masks_per_calc is not None else train_mask
                loss = F.cross_entropy(logits[tm], y[tm])

                opt.zero_grad()
                loss.backward()
                opt.step()

                local_losses.append(loss.item())
                local_states.append({k: v.detach().cpu() for k, v in model.state_dict().items()})

            # FedAvg: average parameters and broadcast back to every local model.
            avg_state = self._average_state_dicts(local_states)
            for model in local_models:
                model.load_state_dict(avg_state, strict=True)

            # All models are identical after averaging; evaluate any one of
            # them on the full adjacency matrix.
            model_eval = local_models[0]
            model_eval.eval()
            with torch.no_grad():
                logits = model_eval(Abar_full.to(self.device), X)
                val_acc = self._accuracy(logits, y, val_mask)

            train_loss_hist.append(sum(local_losses) / len(local_losses))
            val_acc_hist.append(val_acc)

            if val_acc > best_val:
                best_val = val_acc
                best_state = {k: v.detach().cpu() for k, v in model_eval.state_dict().items()}
                epochs_since_improve = 0
            else:
                epochs_since_improve += 1

            if epochs_since_improve >= self.patience:
                break

        stats = {
            "train_loss": train_loss_hist,
            "val_acc": val_acc_hist,
            "epochs_trained": len(train_loss_hist)
        }

        return stats, best_state

    def _average_state_dicts(
        self,
        state_dicts: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """Average parameters elementwise across multiple state dicts."""
        avg = {}
        for key in state_dicts[0].keys():
            stacked = torch.stack([sd[key].float() for sd in state_dicts], dim=0)
            avg[key] = stacked.mean(dim=0)
        return avg

    def _accuracy(
        self,
        logits: torch.Tensor,
        y: torch.Tensor,
        mask: torch.Tensor
    ) -> float:
        """Compute accuracy over the nodes selected by ``mask``."""
        pred = logits.argmax(dim=1)
        return float((pred[mask] == y[mask]).float().mean().item())

    def _get_predictions(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get (predictions, labels) for the test set, both on CPU."""
        data = self.graph_data
        test_mask = data.test_mask.to(self.device)
        y = data.y.to(self.device)

        self.defense_model.eval()
        with torch.no_grad():
            # Inference uses the full Abar and the same DP-noised features
            # the model was trained on.
            logits = self.defense_model(self.Abar.to(self.device), self.X_noised)
            preds = logits.argmax(dim=1)[test_mask]
            labels = y[test_mask]

        return preds.cpu(), labels.cpu()

    def _load_model(self):
        """Load pre-trained model (not implemented for MDP)."""
        pass

    def get_defended_features(self) -> torch.Tensor:
        """Return the DP-noised features (set after defend())."""
        return self.X_noised

    def get_adjacency_shares(self) -> List[torch.Tensor]:
        """Return the eigenvalue-split adjacency shares (set after defend())."""
        return self.Abar_shares

    def get_training_stats(self) -> Optional[Dict]:
        """Return training statistics from the federated loop, if any."""
        return self._training_stats
+from .mdp_gcn import ManualGCN diff --git a/pygip/models/nn/mdp_gcn.py b/pygip/models/nn/mdp_gcn.py new file mode 100644 index 00000000..f40fd3ea --- /dev/null +++ b/pygip/models/nn/mdp_gcn.py @@ -0,0 +1,99 @@ +""" +ManualGCN: A GCN implementation that accepts dense adjacency matrices. + +Unlike standard PyG/DGL GCN layers that expect edge_index or DGLGraph, +this model performs message passing using dense matrix multiplication. +This is required for MDP defense which operates on eigendecomposed +adjacency matrix shares. +""" + +from __future__ import annotations + +from typing import List, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +AdjType = Union[torch.Tensor, List[torch.Tensor]] + + +class ManualGCN(nn.Module): + """ + Two-layer GCN using dense adjacency matrix multiplication. + + Message passing is computed as: H' = A @ H @ W + where A can be a single matrix or list of shares that sum to the full matrix. + + Args: + in_dim: Input feature dimension + hidden_dim: Hidden layer dimension + out_dim: Output dimension (number of classes) + dropout: Dropout probability + """ + + def __init__( + self, + in_dim: int, + hidden_dim: int, + out_dim: int, + dropout: float = 0.5 + ): + super().__init__() + self.lin1 = nn.Linear(in_dim, hidden_dim, bias=True) + self.lin2 = nn.Linear(hidden_dim, out_dim, bias=True) + self.dropout_p = float(dropout) + + def _message_passing( + self, + A_or_shares: AdjType, + H: torch.Tensor + ) -> torch.Tensor: + """ + Perform message passing via dense matrix multiplication. + + If A_or_shares is a list, sums the results of A_i @ H for each share. + This allows training on individual shares while inference uses the full sum. 
+ """ + if isinstance(A_or_shares, list): + out = None + for A in A_or_shares: + A = A.to(dtype=torch.float32, device=H.device) + part = A @ H + out = part if out is None else (out + part) + return out + A = A_or_shares.to(dtype=torch.float32, device=H.device) + return A @ H + + def forward( + self, + A_or_shares: AdjType, + X: torch.Tensor, + dropout: float = None + ) -> torch.Tensor: + """ + Forward pass through 2-layer GCN. + + Args: + A_or_shares: Dense adjacency matrix or list of shares + X: Node feature matrix [N, in_dim] + dropout: Override dropout probability (optional) + + Returns: + Logits tensor [N, out_dim] + """ + X = X.to(dtype=torch.float32) + p = self.dropout_p if dropout is None else float(dropout) + + # Layer 1: Message passing -> Linear -> ReLU -> Dropout + H = self._message_passing(A_or_shares, X) + H = self.lin1(H) + H = F.relu(H) + H = F.dropout(H, p=p, training=self.training) + + # Layer 2: Message passing -> Linear + H = self._message_passing(A_or_shares, H) + H = self.lin2(H) + + return H diff --git a/pygip/utils/mdp/__init__.py b/pygip/utils/mdp/__init__.py new file mode 100644 index 00000000..9f5f738a --- /dev/null +++ b/pygip/utils/mdp/__init__.py @@ -0,0 +1,16 @@ +"""MDP (Matrix Decomposition + Differential Privacy) utilities.""" + +from .abar import build_abar_dense, AbarResult +from .eigenvalue_split import es_split_into_nc, ESNCResult +from .dp_features import dp_features_laplace, DPResult +from .splits import make_overlapping_train_masks + +__all__ = [ + "build_abar_dense", + "AbarResult", + "es_split_into_nc", + "ESNCResult", + "dp_features_laplace", + "DPResult", + "make_overlapping_train_masks", +] diff --git a/pygip/utils/mdp/abar.py b/pygip/utils/mdp/abar.py new file mode 100644 index 00000000..2a13721f --- /dev/null +++ b/pygip/utils/mdp/abar.py @@ -0,0 +1,58 @@ +"""Build normalized adjacency matrix for MDP defense.""" + +from __future__ import annotations + +from dataclasses import dataclass +import time +import torch + + 
+@dataclass +class AbarResult: + """Result of normalized adjacency matrix construction.""" + Abar: torch.Tensor + build_time_sec: float + + +def build_abar_dense( + *, + edge_index: torch.Tensor, + num_nodes: int, + device: torch.device +) -> AbarResult: + """ + Build normalized dense adjacency matrix. + + Computes: Abar = I + D^(-1/2) A D^(-1/2) + + Args: + edge_index: Edge index tensor [2, E] + num_nodes: Number of nodes + device: Torch device + + Returns: + AbarResult containing the normalized adjacency matrix + """ + t0 = time.time() + + n = int(num_nodes) + if edge_index.dim() != 2 or edge_index.size(0) != 2: + raise ValueError("edge_index must have shape [2, E]") + + row = edge_index[0].to(device) + col = edge_index[1].to(device) + + A = torch.zeros((n, n), device=device, dtype=torch.float32) + A[row, col] = 1.0 + A[col, row] = 1.0 + A.fill_diagonal_(0.0) + + deg = A.sum(dim=1) + inv_sqrt = torch.zeros_like(deg) + nz = deg > 0 + inv_sqrt[nz] = deg[nz].pow(-0.5) + + A_norm = inv_sqrt.view(n, 1) * A * inv_sqrt.view(1, n) + Abar = torch.eye(n, device=device, dtype=torch.float32) + A_norm + + return AbarResult(Abar=Abar, build_time_sec=time.time() - t0) diff --git a/pygip/utils/mdp/dp_features.py b/pygip/utils/mdp/dp_features.py new file mode 100644 index 00000000..b11810fa --- /dev/null +++ b/pygip/utils/mdp/dp_features.py @@ -0,0 +1,76 @@ +"""Differential privacy for node features using Laplace mechanism.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional +import torch + + +@dataclass +class DPResult: + """Result of differential privacy application.""" + X_dp: torch.Tensor + epsilon: float + delta: float + scale: float + + +def sample_laplace( + shape, + *, + scale: float, + device: torch.device, + dtype: torch.dtype +) -> torch.Tensor: + """Sample from Laplace distribution.""" + U = torch.rand(shape, device=device, dtype=dtype) - 0.5 + noise = -scale * torch.sign(U) * torch.log1p(-2.0 * torch.abs(U)) + 
def dp_features_laplace(
    X: torch.Tensor,
    *,
    epsilon: float,
    delta: float = 1.0,
    clip_min: float = 0.0,
    clip_max: float = 1.0,
    seed: Optional[int] = None,
) -> DPResult:
    """
    Apply differential privacy to features using the Laplace mechanism.

    Features are clipped to [clip_min, clip_max], noised with Laplace noise
    of scale delta/epsilon, and clipped again.

    Args:
        X: Feature matrix [N, F]
        epsilon: Privacy budget (larger = less noise, inf = no noise)
        delta: Sensitivity parameter
        clip_min: Minimum feature value for clipping
        clip_max: Maximum feature value for clipping
        seed: Random seed for reproducibility
            (NOTE: seeds the global torch RNG as a side effect)

    Returns:
        DPResult with DP-protected features

    Raises:
        ValueError: If epsilon is non-positive and not infinite.
    """
    # Infinite budget: no noise, but still clip for consistency.
    if epsilon == float("inf"):
        X_clipped = torch.clamp(X, min=clip_min, max=clip_max)
        return DPResult(X_dp=X_clipped, epsilon=epsilon, delta=delta, scale=0.0)

    if epsilon <= 0:
        raise ValueError("epsilon must be > 0 (or inf).")

    if seed is not None:
        torch.manual_seed(seed)

    scale = float(delta) / float(epsilon)
    X_clipped = torch.clamp(X, min=clip_min, max=clip_max)
    noise = sample_laplace(
        X_clipped.shape,
        scale=scale,
        device=X_clipped.device,
        dtype=X_clipped.dtype
    )
    X_noisy = X_clipped + noise
    X_dp = torch.clamp(X_noisy, min=clip_min, max=clip_max)

    return DPResult(X_dp=X_dp, epsilon=epsilon, delta=delta, scale=scale)


@dataclass
class ESNCResult:
    """Result of eigenvalue splitting."""
    shares: List[torch.Tensor]   # nc matrices whose sum reconstructs A
    eigvals: torch.Tensor        # eigenvalues of A (real parts), float32
    assignments: torch.Tensor    # share index assigned to each eigenvalue


def es_split_into_nc(
    A: torch.Tensor,
    *,
    nc: int,
    seed: Optional[int] = None,
    assume_symmetric: bool = True,
) -> ESNCResult:
    """
    Split matrix A into nc shares via eigendecomposition.

    Each eigenvalue is randomly assigned to one of nc shares; share k is
    reconstructed from only its assigned eigenvalues, so the shares sum to
    the original matrix: A = sum(shares).

    NOTE: exact reconstruction holds for the symmetric path. The
    non-symmetric path drops imaginary parts and uses a pseudo-inverse, so
    the shares only approximate A there.

    Args:
        A: Square matrix to split [n, n]
        nc: Number of shares to create (>= 2)
        seed: Random seed for reproducible assignments
            (NOTE: seeds the global torch RNG as a side effect)
        assume_symmetric: If True, uses efficient symmetric eigendecomposition

    Returns:
        ESNCResult with list of matrix shares

    Raises:
        ValueError: If A is not square or nc < 2.
    """
    if A.dim() != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square (n,n). Got shape {tuple(A.shape)}")

    if nc < 2:
        raise ValueError("nc must be >= 2")

    if seed is not None:
        torch.manual_seed(seed)

    device = A.device
    n = A.shape[0]

    if assume_symmetric:
        eigvals, U = torch.linalg.eigh(A)
        UT = U.t()
    else:
        eigvals_complex, U_complex = torch.linalg.eig(A)
        eigvals = eigvals_complex.real
        U = U_complex.real
        UT = torch.linalg.pinv(U)

    # Random eigenvalue -> share assignment.
    assignments = torch.randint(low=0, high=nc, size=(n,), device=device)

    shares: List[torch.Tensor] = []
    for k in range(nc):
        mask = (assignments == k)
        lam_k = torch.zeros_like(eigvals)
        lam_k[mask] = eigvals[mask]
        # (U * lam_k) @ UT == U @ diag(lam_k) @ UT, but scales columns
        # directly instead of materializing and multiplying an n x n
        # diagonal matrix.
        A_k = (U * lam_k) @ UT
        shares.append(A_k)

    shares = [s.to(dtype=torch.float32) for s in shares]
    eigvals = eigvals.to(dtype=torch.float32)

    return ESNCResult(shares=shares, eigvals=eigvals, assignments=assignments)


def make_overlapping_train_masks(
    train_mask: torch.Tensor,
    nc: int,
    seed: int,
    keep_ratio: float = 0.8
) -> List[torch.Tensor]:
    """
    Create overlapping training masks for nc calculators.

    Each calculator sees keep_ratio fraction of the training nodes, with
    different random subsets (deduplicated best-effort, up to 50 re-draws
    per calculator) to ensure diversity.

    Args:
        train_mask: Boolean mask of training nodes
        nc: Number of calculators (>= 1)
        seed: Random seed
        keep_ratio: Fraction of training nodes each calculator sees (0, 1]

    Returns:
        List of nc training masks

    Raises:
        ValueError: If train_mask is not boolean, nc < 1, or keep_ratio
            outside (0, 1].
    """
    if train_mask.dtype != torch.bool:
        raise ValueError("train_mask must be a boolean tensor")
    if nc < 1:
        raise ValueError("nc must be >= 1")
    if not (0.0 < keep_ratio <= 1.0):
        raise ValueError("keep_ratio must be in (0, 1]")

    num_nodes = train_mask.numel()
    device = train_mask.device
    train_idx = torch.where(train_mask)[0]
    n_train = train_idx.numel()
    keep_n = max(1, int(round(n_train * keep_ratio)))

    # When every calculator keeps the whole training set, all masks are
    # identical to train_mask — skip the (futile) dedup-retry sampling.
    if keep_n >= n_train:
        return [train_mask.clone() for _ in range(nc)]

    g = torch.Generator(device="cpu")
    g.manual_seed(seed)

    masks: List[torch.Tensor] = []
    used_signatures = set()

    for _i in range(nc):
        for _attempt in range(50):
            perm = torch.randperm(n_train, generator=g)
            chosen = train_idx[perm[:keep_n]].to(device)
            m = torch.zeros(num_nodes, dtype=torch.bool, device=device)
            m[chosen] = True
            sig = tuple(sorted(chosen.tolist()))
            if sig not in used_signatures:
                used_signatures.add(sig)
                masks.append(m)
                break
        else:
            # Could not find an unseen subset in 50 tries; accept a duplicate
            # rather than fail.
            masks.append(m)

    return masks