diff --git a/README.md b/README.md index 509cce6..7dc518c 100644 --- a/README.md +++ b/README.md @@ -81,14 +81,14 @@ pip install -r requirements.txt ### Training ```bash -python train.py --config configs/train_joint.yaml +python train.py --config configs/train.yaml ``` ### Multi-GPU Training (DDP) ```bash python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=8 \ - train.py --config configs/train_joint.yaml + train.py --config configs/train.yaml ``` ### Inference diff --git a/configs/README.md b/configs/README.md index 5027e3f..7543574 100644 --- a/configs/README.md +++ b/configs/README.md @@ -4,11 +4,15 @@ ## 파일 목록 -- `train_joint.yaml` - - **학습 설정** (joint protein-ligand graph 아키텍처) - - Slurm 스크립트 `scripts/slurm/run_train_joint*.sh`에서 기본으로 사용합니다. +- `train.yaml` + - **Cartesian 학습 설정** (per-atom velocity field) + - Slurm 스크립트 `scripts/slurm/run_train_full.sh`에서 기본으로 사용합니다. - 추론 시에도 이 설정을 기반으로 모델을 로드합니다. +- `train_torsion.yaml` + - **SE(3) + Torsion 학습 설정** (translation + rotation + torsion) + - `train_torsion.py`에서 사용합니다. + ## 공통 구조(개요) 설정 파일은 아래 섹션을 갖습니다. 
@@ -34,21 +38,21 @@ ### 학습 (로컬) ```bash -python train.py --config configs/train_joint.yaml +python train.py --config configs/train.yaml ``` ### 멀티 GPU 학습 (DDP, 로컬/서버) ```bash python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=8 \ - train.py --config configs/train_joint.yaml + train.py --config configs/train.yaml ``` ### 추론/평가 ```bash python inference.py \ - --config configs/train_joint.yaml \ + --config configs/train.yaml \ --checkpoint /path/to/checkpoint.pt \ --device cuda ``` diff --git a/configs/overfit_test.yaml b/configs/overfit_test.yaml deleted file mode 100644 index 1f87db6..0000000 --- a/configs/overfit_test.yaml +++ /dev/null @@ -1,116 +0,0 @@ -# FlowFix Overfitting Test Configuration -# Based on successful model: flowfix_20251104_151853 -# Uses ONLY simple flow matching loss (no auxiliary losses) - -device: cuda -seed: 42 - -# Data - Small dataset for overfitting test (Dynamic Pose Sampling) -data: - data_dir: train_data # ✅ Relative path to current directory - split_file: null - max_train_samples: 50 # Small dataset for overfitting test - max_val_samples: 5 # Validation samples - num_workers: 4 - # Note: Each epoch samples different poses per PDB for augmentation - -# Model dimensions (same as successful model) -model: - # Protein network - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - protein_hidden_scalar_dim: 64 - protein_hidden_vector_dim: 16 - protein_output_scalar_dim: 64 - protein_output_vector_dim: 16 - protein_num_layers: 4 - - # Ligand network - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - ligand_hidden_scalar_dim: 64 - ligand_hidden_vector_dim: 16 - ligand_output_scalar_dim: 64 - ligand_output_vector_dim: 16 - ligand_num_layers: 4 - - # Interaction network (CRITICAL: 4 layers, not 8!) 
- interaction_hidden_dim: 256 - interaction_num_heads: 8 - interaction_num_rbf: 32 # Rich distance encoding - interaction_pair_dim: 64 # Expressive pair bias - interaction_num_layers: 4 # ✅ Matches successful model (13M params) - - # Velocity predictor - velocity_hidden_scalar_dim: 64 - velocity_hidden_vector_dim: 16 - velocity_num_layers: 4 - - # Regularization - dropout: 0.1 # ✅ Enable dropout - -# Training - Settings from successful model -training: - num_epochs: 500 - batch_size: 2 - num_timesteps_per_sample: 16 # Sample 16 timesteps per system - val_batch_size: 8 - learning_rate: 0.001 # 1e-3 for fast convergence on small data - gradient_clip: 100.0 - - optimizer: - type: adam - weight_decay: 0.0 # No regularization for overfitting - betas: [0.9, 0.999] - eps: 1.0e-08 - - scheduler: - # CosineAnnealingWarmRestarts - eta_max: 0.01 - T_0: 10 # Restart every 10 epochs - T_mult: 1 # Same period - T_up: 2 # Warmup for 2 epochs - gamma: 0.95 # Decay by 5% each restart - - validation: - frequency: 150 # ✅ Validate every 10 epochs - save_best: true - early_stopping_patience: 100 - -# Sampling (for validation) -sampling: - num_steps: 50 - method: euler - -# Experiment management -experiment: - base_dir: save - -# Checkpoints -checkpoint: - save_freq: 10 - save_latest: true - keep_last_n: 10 - -# Visualization -visualization: - enabled: true - save_animation: true - animation_fps: 10 - num_samples: 1 - -# WandB logging -wandb: - enabled: true - project: "protein-ligand-flowfix-overfit" - entity: null - name: null # Auto-generated - tags: ["overfit-test", "simple-flow-matching", "5-samples"] - log_gradients: true - log_model_weights: true - log_animations: true - log_parameter_histograms: true - log_parameter_evolution: true - log_layer_analysis: true diff --git a/configs/train_fast_overfit.yaml b/configs/train_fast_overfit.yaml deleted file mode 100644 index e6aa93c..0000000 --- a/configs/train_fast_overfit.yaml +++ /dev/null @@ -1,102 +0,0 @@ -# Fast Overfit Test: Sink 
Flow, Small Batch, High LR -data: - data_dir: train_data - split_file: train_data/splits_overfit_32.json - max_train_samples: null - max_val_samples: 4 - fix_pose: true - fix_pose_high_rmsd: true - position_noise_scale: 0.05 # Reduced noise for faster overfitting - num_workers: 4 - loading_mode: lazy - timestep_sampling: - method: uniform - mu: 0.8 - sigma: 1.7 - num_timesteps_per_sample: 1 - -model: - architecture: joint - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - hidden_scalar_dim: 128 - hidden_vector_dim: 32 - hidden_edge_dim: 128 - joint_num_layers: 6 - cross_edge_distance_cutoff: 6.0 - cross_edge_max_neighbors: 16 - cross_edge_num_rbf: 32 - intra_edge_distance_cutoff: 6.0 - intra_edge_max_neighbors: 16 - hidden_dim: 256 - esm_proj_dim: 128 - self_conditioning: false - dropout: 0.0 - -training: - num_epochs: 300 - batch_size: 4 # 8 updates per epoch (32/4) - val_batch_size: 4 - gradient_clip: 10.0 # Tighter clip for stability with sink flow - gradient_accumulation_steps: 1 - distance_geometry_weight: 0.1 - ema: - enabled: false - compile: false - - optimizer: - type: muon - muon: - lr: 0.05 # Higher LR for fast overfit - momentum: 0.95 - weight_decay: 0.0 - min_lr: 0.005 - adamw: - lr: 0.001 # Higher AdamW LR too - weight_decay: 0.0 - betas: [0.9, 0.999] - eps: 1.0e-08 - min_lr: 0.0001 - - schedule: - warmup_fraction: 0.05 - plateau_fraction: 0.80 - decay_fraction: 0.15 - warmup_epochs: 5 - - validation: - frequency: 20 - early_stopping_patience: 1000 - -sampling: - num_steps: 20 - method: euler - schedule: uniform - t_end: 1.0 - -experiment: - base_dir: save - -checkpoint: - save_freq: 20 - save_latest: true - keep_last_n: 3 - -visualization: - enabled: true - save_animation: true - animation_fps: 10 - num_samples: 1 - -wandb: - enabled: true - project: "protein-ligand-flowfix" - name: "overfit-fast-sink" 
- tags: ["overfit-test", "sink-flow", "fast-overfit"] - log_gradients: false - log_model_weights: false - log_animations: true diff --git a/configs/train_joint.yaml b/configs/train_joint.yaml deleted file mode 100644 index ad8d3df..0000000 --- a/configs/train_joint.yaml +++ /dev/null @@ -1,143 +0,0 @@ -# FlowFix Joint Graph Architecture Training Configuration -# Uses joint protein-ligand graph with cross-edges instead of separate encoders + attention - -device: cuda -seed: 42 - -# Data -data: - data_dir: train_data - split_file: train_data/splits.json - max_train_samples: null - max_val_samples: null - num_workers: 8 - loading_mode: lazy - # Timestep sampling for flow matching - timestep_sampling: - method: uniform # uniform for pose refinement (t=0 is structure, not noise) - mu: 0.8 # unused for uniform - sigma: 1.7 # unused for uniform - num_timesteps_per_sample: 1 # K=1 for debugging (was 4) - -# Model - Joint Graph Architecture -model: - architecture: joint # "separate" (default) or "joint" (new) - - # Protein input dims (unchanged) - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - - # Ligand input dims (unchanged) - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - - # Joint network hidden dims (increased from 64/16 for better capacity) - hidden_scalar_dim: 128 - hidden_vector_dim: 32 - hidden_edge_dim: 128 - joint_num_layers: 6 # Unified: message passing + velocity prediction in one network - - # Cross-edge parameters (protein-ligand) - cross_edge_distance_cutoff: 6.0 # Angstroms - cross_edge_max_neighbors: 16 # Max neighbors per node (KNN cap within cutoff) - cross_edge_num_rbf: 32 # RBF features for cross-edge distances - - # Intra-edge parameters (dynamic within protein/ligand, supplements pre-computed edges) - intra_edge_distance_cutoff: 6.0 # Angstroms - intra_edge_max_neighbors: 16 # Max neighbors per node - - # Time conditioning hidden dim - hidden_dim: 
256 - - # ESM embedding integration - esm_proj_dim: 128 # ESM projection dim (gated concatenation to protein features) - - # Self-conditioning: condition on predicted x1 (crystal) from first pass (50% of training) - # Standard in protein diffusion (RFdiffusion, FrameDiff, FrameFlow) - self_conditioning: true - - dropout: 0.1 - -# Training -training: - num_epochs: 1000 - batch_size: 128 # Per-GPU batch size (8x GPU = 1024 effective) - val_batch_size: 128 - gradient_clip: 100.0 - gradient_accumulation_steps: 1 # batch already large enough (256 × 8 = 2048) - distance_geometry_weight: 0.1 - ema: - enabled: true - decay: 0.999 - compile: true # torch_cluster functions are excluded via @torch._dynamo.disable - compile_options: - mode: default - dynamic: true - - protein_ligand_clash: - enabled: false - ca_threshold: 3.0 - sc_threshold: 2.5 - margin: 1.0 - weight: 0.2 - - optimizer: - type: muon - muon: - lr: 0.005 # Reduced from 0.02 to stabilize training - momentum: 0.95 - weight_decay: 0.0 - min_lr: 0.0005 # Cosine decay target (10% of peak) - adamw: - lr: 0.0003 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1.0e-08 - min_lr: 0.00003 # Cosine decay target (10% of peak) - - # Unified LR schedule: linear warmup + plateau + cosine decay - schedule: - warmup_fraction: 0.05 # 5% of total epochs (0 -> max LR) - plateau_fraction: 0.80 # 80% of total epochs (keep max LR) - decay_fraction: 0.15 # 15% of total epochs (cosine max -> min LR) - warmup_epochs: 5 # Fallback for legacy behavior if fractions are not used - - validation: - frequency: 20 - early_stopping_patience: 30 - -# Sampling (Inference/Validation) -sampling: - num_steps: 20 - method: euler - schedule: quadratic - -# Experiment management -experiment: - base_dir: save - -# Checkpoints -checkpoint: - save_freq: 10 - save_latest: true - keep_last_n: 5 - -# Visualization -visualization: - enabled: true - save_animation: true - animation_fps: 10 - num_samples: 1 - -# WandB logging -wandb: - enabled: true - project: 
"protein-ligand-flowfix" - entity: null - name: null - tags: ["flow-matching", "protein-ligand", "se3-equivariant", "joint-graph"] - log_gradients: true - log_model_weights: true - log_animations: true diff --git a/configs/train_joint_test.yaml b/configs/train_joint_test.yaml deleted file mode 100644 index 8155ba5..0000000 --- a/configs/train_joint_test.yaml +++ /dev/null @@ -1,139 +0,0 @@ -# FlowFix Joint Graph - Test Configuration (Single A5000 GPU) -# Reduced batch size for 24GB VRAM - -device: cuda -seed: 42 - -# Data -data: - data_dir: train_data - split_file: train_data/splits.json - max_train_samples: null - max_val_samples: null - num_workers: 4 - loading_mode: lazy - timestep_sampling: - method: uniform - mu: 0.8 - sigma: 1.7 - -# Model - Joint Graph Architecture -model: - architecture: joint - - # Protein input dims - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - - # Ligand input dims - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - - # Joint network hidden dims - hidden_scalar_dim: 128 - hidden_vector_dim: 32 - hidden_edge_dim: 128 - joint_num_layers: 6 - - # Cross-edge parameters - cross_edge_distance_cutoff: 6.0 - cross_edge_max_neighbors: 16 - cross_edge_num_rbf: 32 - - # Intra-edge parameters - intra_edge_distance_cutoff: 6.0 - intra_edge_max_neighbors: 16 - - # Time conditioning hidden dim - hidden_dim: 256 - - # ESM embedding integration - esm_proj_dim: 128 - - # Self-conditioning - self_conditioning: true - - dropout: 0.1 - -# Training - Reduced for single A5000 (24GB) -training: - num_epochs: 100 # Reduced for quick test - batch_size: 32 # Reduced from 1024 for A5000 (24GB) - val_batch_size: 32 - gradient_clip: 100.0 - gradient_accumulation_steps: 1 - distance_geometry_weight: 0.1 - ema: - enabled: true - decay: 0.999 - compile: false # Disabled for debugging - compile_options: - mode: default - dynamic: true - - protein_ligand_clash: - 
enabled: false - ca_threshold: 3.0 - sc_threshold: 2.5 - margin: 1.0 - weight: 0.2 - - optimizer: - type: muon - muon: - lr: 0.02 - momentum: 0.95 - weight_decay: 0.0 - min_lr: 0.002 - adamw: - lr: 0.0003 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1.0e-08 - min_lr: 0.00003 - - schedule: - warmup_fraction: 0.05 - plateau_fraction: 0.80 - decay_fraction: 0.15 - warmup_epochs: 5 - - validation: - frequency: 10 # More frequent validation for test - early_stopping_patience: 30 - -# Sampling (Inference/Validation) -sampling: - num_steps: 10 # Reduced for faster validation - method: euler - schedule: quadratic - -# Experiment management -experiment: - base_dir: save - -# Checkpoints -checkpoint: - save_freq: 10 - save_latest: true - keep_last_n: 3 - -# Visualization -visualization: - enabled: false # Disabled for quick test - save_animation: false - animation_fps: 10 - num_samples: 1 - -# WandB logging -wandb: - enabled: true - project: "protein-ligand-flowfix" - entity: null - name: null - tags: ["flow-matching", "protein-ligand", "se3-equivariant", "joint-graph", "test"] - log_gradients: true - log_model_weights: true - log_animations: false diff --git a/configs/train_overfit_32.yaml b/configs/train_overfit_32.yaml deleted file mode 100644 index 3e14c04..0000000 --- a/configs/train_overfit_32.yaml +++ /dev/null @@ -1,124 +0,0 @@ -# Overfit Test: 32 training samples, val=subset of train -# Purpose: More stable BatchNorm stats while still testing overfitting - -device: cuda -seed: 42 - -# Data - 32 train, 4 val (subset of train) -data: - data_dir: train_data - split_file: train_data/splits_overfit_32.json - max_train_samples: null - max_val_samples: 4 - fix_pose: true - fix_pose_high_rmsd: false - position_noise_scale: 0.1 - num_workers: 4 - loading_mode: lazy - timestep_sampling: - method: uniform - mu: 0.8 - sigma: 1.7 - num_timesteps_per_sample: 1 - -# Model -model: - architecture: joint - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - 
protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - hidden_scalar_dim: 128 - hidden_vector_dim: 32 - hidden_edge_dim: 128 - joint_num_layers: 6 - cross_edge_distance_cutoff: 6.0 - cross_edge_max_neighbors: 16 - cross_edge_num_rbf: 32 - intra_edge_distance_cutoff: 6.0 - intra_edge_max_neighbors: 16 - hidden_dim: 256 - esm_proj_dim: 128 - self_conditioning: true - dropout: 0.0 - -# Training -training: - num_epochs: 500 - batch_size: 16 # Larger batch for BatchNorm stability - val_batch_size: 4 - gradient_clip: 100.0 - gradient_accumulation_steps: 1 - distance_geometry_weight: 0.1 - ema: - enabled: false - decay: 0.999 - compile: false - - protein_ligand_clash: - enabled: false - ca_threshold: 3.0 - sc_threshold: 2.5 - margin: 1.0 - weight: 0.2 - - optimizer: - type: muon - muon: - lr: 0.02 - momentum: 0.95 - weight_decay: 0.0 - min_lr: 0.002 - adamw: - lr: 0.0003 - weight_decay: 0.0 - betas: [0.9, 0.999] - eps: 1.0e-08 - min_lr: 0.00003 - - schedule: - warmup_fraction: 0.05 - plateau_fraction: 0.80 - decay_fraction: 0.15 - warmup_epochs: 5 - - validation: - frequency: 50 - early_stopping_patience: 1000 - -# Sampling -sampling: - num_steps: 20 - method: euler - schedule: uniform - t_end: 1.0 - -# Experiment -experiment: - base_dir: save - -# Checkpoints -checkpoint: - save_freq: 50 - save_latest: true - keep_last_n: 3 - -# Visualization -visualization: - enabled: true - save_animation: true - animation_fps: 10 - num_samples: 1 - -# WandB -wandb: - enabled: true - project: "protein-ligand-flowfix" - entity: null - name: "overfit-test-32-fixed" - tags: ["overfit-test", "batchnorm-fix", "fixed-edges"] - log_gradients: false - log_model_weights: false - log_animations: true diff --git a/configs/train_overfit_high_rmsd.yaml b/configs/train_overfit_high_rmsd.yaml deleted file mode 100644 index 5f1beec..0000000 --- a/configs/train_overfit_high_rmsd.yaml +++ /dev/null @@ -1,124 +0,0 @@ -# 
Overfit Test: 32 training samples, val=subset of train -# Purpose: More stable BatchNorm stats while still testing overfitting - -device: cuda -seed: 42 - -# Data - 32 train, 4 val (subset of train) -data: - data_dir: train_data - split_file: train_data/splits_overfit_32.json - max_train_samples: null - max_val_samples: 4 - fix_pose: true - fix_pose_high_rmsd: true - position_noise_scale: 0.1 - num_workers: 4 - loading_mode: lazy - timestep_sampling: - method: uniform - mu: 0.8 - sigma: 1.7 - num_timesteps_per_sample: 1 - -# Model -model: - architecture: joint - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - hidden_scalar_dim: 128 - hidden_vector_dim: 32 - hidden_edge_dim: 128 - joint_num_layers: 6 - cross_edge_distance_cutoff: 6.0 - cross_edge_max_neighbors: 16 - cross_edge_num_rbf: 32 - intra_edge_distance_cutoff: 6.0 - intra_edge_max_neighbors: 16 - hidden_dim: 256 - esm_proj_dim: 128 - self_conditioning: true - dropout: 0.0 - -# Training -training: - num_epochs: 500 - batch_size: 16 # Larger batch for BatchNorm stability - val_batch_size: 4 - gradient_clip: 100.0 - gradient_accumulation_steps: 1 - distance_geometry_weight: 0.1 - ema: - enabled: false - decay: 0.999 - compile: false - - protein_ligand_clash: - enabled: false - ca_threshold: 3.0 - sc_threshold: 2.5 - margin: 1.0 - weight: 0.2 - - optimizer: - type: muon - muon: - lr: 0.02 - momentum: 0.95 - weight_decay: 0.0 - min_lr: 0.002 - adamw: - lr: 0.0003 - weight_decay: 0.0 - betas: [0.9, 0.999] - eps: 1.0e-08 - min_lr: 0.00003 - - schedule: - warmup_fraction: 0.05 - plateau_fraction: 0.80 - decay_fraction: 0.15 - warmup_epochs: 5 - - validation: - frequency: 50 - early_stopping_patience: 1000 - -# Sampling -sampling: - num_steps: 20 - method: euler - schedule: uniform - t_end: 1.0 - -# Experiment -experiment: - base_dir: save - -# Checkpoints -checkpoint: 
- save_freq: 50 - save_latest: true - keep_last_n: 3 - -# Visualization -visualization: - enabled: true - save_animation: true - animation_fps: 10 - num_samples: 1 - -# WandB -wandb: - enabled: true - project: "protein-ligand-flowfix" - entity: null - name: "overfit-test-high-rmsd" - tags: ["high-rmsd", "overfit-test", "batchnorm-fix", "fixed-edges"] - log_gradients: false - log_model_weights: false - log_animations: true diff --git a/configs/train_overfit_test.yaml b/configs/train_overfit_test.yaml deleted file mode 100644 index cb8073d..0000000 --- a/configs/train_overfit_test.yaml +++ /dev/null @@ -1,124 +0,0 @@ -# Overfit Test: Small dataset (32 samples), train=val -# Purpose: Verify model can learn/overfit to same data - -device: cuda -seed: 42 - -# Data - tiny overfit split -data: - data_dir: train_data - split_file: train_data/splits_overfit_tiny.json # 8 samples, train=val - max_train_samples: null - max_val_samples: 8 # Only validate on 8 samples for speed - fix_pose: true # Use pose 1 (easier, RMSD ~1A) - fix_pose_high_rmsd: false # Disable high RMSD poses for now - position_noise_scale: 0.1 # Add 0.1A noise to x_t for ODE robustness (smaller for stability) - num_workers: 4 - loading_mode: lazy - timestep_sampling: - method: uniform - mu: 0.8 - sigma: 1.7 - num_timesteps_per_sample: 1 - -# Model - same as main config -model: - architecture: joint - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - hidden_scalar_dim: 128 - hidden_vector_dim: 32 - hidden_edge_dim: 128 - joint_num_layers: 6 - cross_edge_distance_cutoff: 6.0 - cross_edge_max_neighbors: 16 - cross_edge_num_rbf: 32 - intra_edge_distance_cutoff: 6.0 - intra_edge_max_neighbors: 16 - hidden_dim: 256 - esm_proj_dim: 128 - self_conditioning: true # Keep enabled, helps learning - dropout: 0.0 # No dropout for overfit test - -# Training - fast iteration 
-training: - num_epochs: 1000 - batch_size: 8 # All 8 samples in one batch - val_batch_size: 8 - gradient_clip: 100.0 - gradient_accumulation_steps: 1 - distance_geometry_weight: 0.1 - ema: - enabled: false # Disable EMA for faster iteration - decay: 0.999 - compile: false # Disable compile for faster startup - - protein_ligand_clash: - enabled: false - ca_threshold: 3.0 - sc_threshold: 2.5 - margin: 1.0 - weight: 0.2 - - optimizer: - type: muon - muon: - lr: 0.02 # Muon can use high LR - momentum: 0.95 - weight_decay: 0.0 - min_lr: 0.002 - adamw: - lr: 0.0003 - weight_decay: 0.0 # No weight decay for overfit test - betas: [0.9, 0.999] - eps: 1.0e-08 - min_lr: 0.00003 - - schedule: - warmup_fraction: 0.05 - plateau_fraction: 0.80 - decay_fraction: 0.15 - warmup_epochs: 5 - - validation: - frequency: 100 # Validate every 100 epochs for closer monitoring - early_stopping_patience: 1000 # Don't early stop - -# Sampling -sampling: - num_steps: 20 # More steps for smoother integration - method: euler - schedule: uniform # Linear spacing - t_end: 1.0 # Full integration t=0~1 - -# Experiment -experiment: - base_dir: save - -# Checkpoints -checkpoint: - save_freq: 50 - save_latest: true - keep_last_n: 3 - -# Visualization -visualization: - enabled: true - save_animation: true - animation_fps: 10 - num_samples: 1 - -# WandB -wandb: - enabled: true - project: "protein-ligand-flowfix" - entity: null - name: "overfit-test-8" - tags: ["overfit-test", "tiny"] - log_gradients: false - log_model_weights: false - log_animations: true diff --git a/configs/train_rectified_flow.yaml b/configs/train_rectified_flow.yaml deleted file mode 100644 index f9a6693..0000000 --- a/configs/train_rectified_flow.yaml +++ /dev/null @@ -1,102 +0,0 @@ -# Rectified Flow Test: Stable Linear Interpolation -data: - data_dir: train_data - split_file: train_data/splits_overfit_32.json - max_train_samples: null - max_val_samples: 4 - fix_pose: true - fix_pose_high_rmsd: true - position_noise_scale: 0.05 - 
num_workers: 4 - loading_mode: lazy - timestep_sampling: - method: uniform - mu: 0.8 - sigma: 1.7 - num_timesteps_per_sample: 1 - -model: - architecture: joint - protein_input_scalar_dim: 76 - protein_input_vector_dim: 31 - protein_input_edge_scalar_dim: 39 - protein_input_edge_vector_dim: 8 - ligand_input_scalar_dim: 122 - ligand_input_edge_scalar_dim: 44 - hidden_scalar_dim: 128 - hidden_vector_dim: 32 - hidden_edge_dim: 128 - joint_num_layers: 6 - cross_edge_distance_cutoff: 6.0 - cross_edge_max_neighbors: 16 - cross_edge_num_rbf: 32 - intra_edge_distance_cutoff: 6.0 - intra_edge_max_neighbors: 16 - hidden_dim: 256 - esm_proj_dim: 128 - self_conditioning: false - dropout: 0.0 - -training: - num_epochs: 300 - batch_size: 4 # 8 updates per epoch (32/4) - val_batch_size: 4 - gradient_clip: 1.0 # Standard clip is fine for Rectified Flow - gradient_accumulation_steps: 1 - distance_geometry_weight: 0.1 - ema: - enabled: false - compile: false - - optimizer: - type: muon - muon: - lr: 0.05 - momentum: 0.95 - weight_decay: 0.0 - min_lr: 0.005 - adamw: - lr: 0.001 - weight_decay: 0.0 - betas: [0.9, 0.999] - eps: 1.0e-08 - min_lr: 0.0001 - - schedule: - warmup_fraction: 0.05 - plateau_fraction: 0.80 - decay_fraction: 0.15 - warmup_epochs: 5 - - validation: - frequency: 20 - early_stopping_patience: 1000 - -sampling: - num_steps: 20 - method: euler - schedule: uniform - t_end: 1.0 - -experiment: - base_dir: save - -checkpoint: - save_freq: 20 - save_latest: true - keep_last_n: 3 - -visualization: - enabled: true - save_animation: true - animation_fps: 10 - num_samples: 1 - -wandb: - enabled: true - project: "protein-ligand-flowfix" - name: "rectified-flow-stable" - tags: ["overfit-test", "rectified-flow", "stable"] - log_gradients: false - log_model_weights: false - log_animations: true diff --git a/configs/train_torsion.yaml b/configs/train_torsion.yaml new file mode 100644 index 0000000..aa95cdc --- /dev/null +++ b/configs/train_torsion.yaml @@ -0,0 +1,117 @@ +# FlowFix 
SE(3) + Torsion Decomposition Training Configuration +# Output: translation [3] + rotation [3] + torsion [M] instead of per-atom velocity [N, 3] + +device: cuda +seed: 42 + +# Data +data: + data_dir: train_data + split_file: train_data/splits.json + max_train_samples: null + max_val_samples: null + num_workers: 8 + loading_mode: lazy + timestep_sampling: + method: uniform + num_timesteps_per_sample: 1 + +# Model +# Uses ProteinLigandFlowMatchingTorsion (separate encoders + interaction + torsion heads) +# Run with: python train_torsion.py --config configs/train_torsion.yaml +model: + + # Protein input dims + protein_input_scalar_dim: 76 + protein_input_vector_dim: 31 + protein_input_edge_scalar_dim: 39 + protein_input_edge_vector_dim: 8 + + # Ligand input dims + ligand_input_scalar_dim: 122 + ligand_input_edge_scalar_dim: 44 + + # Hidden dims + hidden_scalar_dim: 128 + hidden_vector_dim: 32 + hidden_edge_dim: 128 + + # Time conditioning + hidden_dim: 256 + + # ESM embedding + esm_proj_dim: 128 + + self_conditioning: false + dropout: 0.1 + +# Training +training: + num_epochs: 500 + batch_size: 32 + val_batch_size: 32 + gradient_clip: 10.0 + gradient_accumulation_steps: 1 + + # SE(3) + Torsion loss weights + torsion_loss: + w_trans: 1.0 # Translation MSE weight + w_rot: 1.0 # Rotation MSE weight + w_tor: 1.0 # Torsion circular MSE weight + w_coord: 0.5 # Coordinate reconstruction weight (end-to-end) + + distance_geometry_weight: 0.0 # Disabled for torsion mode (built-in constraints) + + ema: + enabled: true + decay: 0.999 + compile: false # Disable for initial debugging + + optimizer: + type: muon + muon: + lr: 0.005 + momentum: 0.95 + weight_decay: 0.0 + min_lr: 0.0005 + adamw: + lr: 0.0003 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1.0e-08 + min_lr: 0.00003 + + schedule: + warmup_fraction: 0.05 + plateau_fraction: 0.80 + decay_fraction: 0.15 + + validation: + frequency: 20 + early_stopping_patience: 30 + +# Sampling +sampling: + num_steps: 20 + method: euler 
+ schedule: uniform + +# Experiment +experiment: + base_dir: save + +checkpoint: + save_freq: 10 + save_latest: true + keep_last_n: 5 + +visualization: + enabled: true + save_animation: true + animation_fps: 10 + num_samples: 1 + +wandb: + enabled: true + project: "protein-ligand-flowfix" + tags: ["flow-matching", "se3-torsion", "decomposition"] diff --git a/output/cartesian_v4_baseline/README.md b/output/cartesian_v4_baseline/README.md new file mode 100644 index 0000000..7e97b64 --- /dev/null +++ b/output/cartesian_v4_baseline/README.md @@ -0,0 +1,22 @@ +# Cartesian v4 Baseline (rectified-flow-full-v4) + +## Model +- **Architecture**: Joint graph, 6-layer GatingEquivariantLayer +- **Output**: Per-atom Cartesian velocity [N, 3] +- **Epochs**: 500 +- **Checkpoint**: `save/rectified-flow-full-v4/checkpoints/latest.pt` (1.1GB) + +## Results (Full Validation, 200 PDBs) + +| Metric | Before | After | Delta | +|--------|--------|-------|-------| +| Mean RMSD | 3.20 A | 2.64 A | -0.56 A | +| Median RMSD | 3.00 A | 2.22 A | -0.78 A | + +## Files + +- `train_config.yaml` - Training configuration used +- `full_validation_results.json` - Full validation inference results (200 PDBs, all poses) +- `latest_train_valid_5.json` - Small train/valid 5 sample inference +- `train5_allposes_latest.json` - Train 5 all poses inference +- Visualization plots: `reports/cartesian/assets/` (RMSD distribution, scatter, improvement) diff --git a/inference_results/full_validation_results.json b/output/cartesian_v4_baseline/full_validation_results.json similarity index 100% rename from inference_results/full_validation_results.json rename to output/cartesian_v4_baseline/full_validation_results.json diff --git a/inference_results/latest_train_valid_5.json b/output/cartesian_v4_baseline/latest_train_valid_5.json similarity index 100% rename from inference_results/latest_train_valid_5.json rename to output/cartesian_v4_baseline/latest_train_valid_5.json diff --git 
a/inference_results/train5_allposes_latest.json b/output/cartesian_v4_baseline/train5_allposes_latest.json similarity index 100% rename from inference_results/train5_allposes_latest.json rename to output/cartesian_v4_baseline/train5_allposes_latest.json diff --git a/configs/train_rectified_flow_full.yaml b/output/cartesian_v4_baseline/train_config.yaml similarity index 66% rename from configs/train_rectified_flow_full.yaml rename to output/cartesian_v4_baseline/train_config.yaml index 476b606..311b6ba 100644 --- a/configs/train_rectified_flow_full.yaml +++ b/output/cartesian_v4_baseline/train_config.yaml @@ -1,9 +1,8 @@ -# Rectified Flow Full Training data: data_dir: train_data split_file: train_data/splits.json max_train_samples: null - max_val_samples: null # Use full validation set + max_val_samples: null fix_pose: false fix_pose_high_rmsd: false position_noise_scale: 0.05 @@ -14,7 +13,6 @@ data: mu: 0.8 sigma: 1.7 num_timesteps_per_sample: 2 - model: architecture: joint protein_input_scalar_dim: 76 @@ -35,64 +33,60 @@ model: hidden_dim: 384 esm_proj_dim: 128 self_conditioning: false - dropout: 0.1 # Add dropout for generalization - + dropout: 0.1 training: num_epochs: 500 - batch_size: 32 # Balanced for v4 (1.5x scale) - val_batch_size: 32 # Increased for faster validation + batch_size: 32 + val_batch_size: 32 gradient_clip: 1.0 gradient_accumulation_steps: 1 distance_geometry_weight: 0.1 ema: - enabled: true # Enable EMA for stable implementation + enabled: true decay: 0.999 compile: false - optimizer: - type: adamw # Safe default for full run, or muon if tested + type: adamw adamw: lr: 0.0001 weight_decay: 0.0001 - betas: [0.9, 0.999] + betas: + - 0.9 + - 0.999 eps: 1.0e-08 - min_lr: 1.0e-6 - + min_lr: 1.0e-06 schedule: warmup_fraction: 0.05 - plateau_fraction: 0.80 + plateau_fraction: 0.8 decay_fraction: 0.15 warmup_epochs: 5 - validation: - frequency: 20 # Validate every 20 epochs + frequency: 20 early_stopping_patience: 50 - sampling: num_steps: 20 method: 
euler schedule: uniform t_end: 1.0 - experiment: base_dir: save - checkpoint: save_freq: 5 save_latest: true keep_last_n: 5 - visualization: enabled: true - save_animation: false # Disable animation save to save space/time on full run + save_animation: false animation_fps: 10 num_samples: 4 - wandb: enabled: true - project: "protein-ligand-flowfix" - name: "rectified-flow-full-v4" - tags: ["full-train", "rectified-flow", "stable"] + project: protein-ligand-flowfix + name: rectified-flow-full-v4 + tags: + - full-train + - rectified-flow + - stable log_gradients: false log_model_weights: false log_animations: true diff --git a/reports/assets/hist_improvement.png b/reports/assets/hist_improvement.png deleted file mode 100644 index 0d5a1da..0000000 Binary files a/reports/assets/hist_improvement.png and /dev/null differ diff --git a/reports/assets/hist_rmsd_distribution.png b/reports/assets/hist_rmsd_distribution.png deleted file mode 100644 index 4b04ad8..0000000 Binary files a/reports/assets/hist_rmsd_distribution.png and /dev/null differ diff --git a/reports/assets/scatter_atoms_vs_improvement.png b/reports/assets/scatter_atoms_vs_improvement.png deleted file mode 100644 index e1bbd6d..0000000 Binary files a/reports/assets/scatter_atoms_vs_improvement.png and /dev/null differ diff --git a/reports/assets/scatter_init_vs_final_pdb.png b/reports/assets/scatter_init_vs_final_pdb.png deleted file mode 100644 index 3e04255..0000000 Binary files a/reports/assets/scatter_init_vs_final_pdb.png and /dev/null differ diff --git a/reports/assets/scatter_init_vs_final_pose.png b/reports/assets/scatter_init_vs_final_pose.png deleted file mode 100644 index 43d3ba3..0000000 Binary files a/reports/assets/scatter_init_vs_final_pose.png and /dev/null differ diff --git a/reports/assets/scatter_init_vs_improvement.png b/reports/assets/scatter_init_vs_improvement.png deleted file mode 100644 index ec541fa..0000000 Binary files a/reports/assets/scatter_init_vs_improvement.png and /dev/null 
differ diff --git a/reports/architecture.md b/reports/cartesian/architecture.md similarity index 93% rename from reports/architecture.md rename to reports/cartesian/architecture.md index b29eb86..5da8a40 100644 --- a/reports/architecture.md +++ b/reports/cartesian/architecture.md @@ -167,14 +167,14 @@ flowchart TB Pre-trained PLM의 residue-level embedding을 protein features에 통합. ```mermaid -flowchart LR - ESMC["ESMC 600M
[N, 1152]"] --> P1["MLP<br/>1152->128->128"] - ESM3["ESM3<br/>[N, 1536]"] --> P2["MLP<br/>1536->128->128"] +flowchart TB + ESMC["ESMC 600M [N, 1152]"] --> P1["MLP 1152->128->128"] + ESM3["ESM3 [N, 1536]"] --> P2["MLP 1536->128->128"] P1 --> WS["Weighted Sum<br/>softmax(esm_weight)"] P2 --> WS - WS --> GATE["esm_gate<br/>sigmoid MLP"] + WS --> GATE["esm_gate (sigmoid MLP)"] PX["protein.x [N,76]"] --> CAT["Gated Concat"] GATE --> CAT @@ -204,26 +204,23 @@ flowchart LR ### 3.1 Flow Matching Loss ```mermaid -flowchart LR +flowchart TB subgraph sample["Sampling"] X0["x_0 (docked)"] X1["x_1 (crystal)"] T["t ~ Uniform(0,1)"] end - INTERP["x_t = (1-t)*x0 + t*x1"] - FWD["v_pred = model(prot, lig_t, t)"] - VT["v_true = x1 - x0"] - MSE["L_flow = MSE(v_pred, v_true)"] - DG["L_dg = dist geometry<br/>bond constraints<br/>time-weighted"] - TOTAL["L = L_flow + 0.1 * L_dg"] - - X0 --> INTERP + X0 --> INTERP["x_t = (1-t)*x0 + t*x1"] X1 --> INTERP T --> INTERP - INTERP --> FWD --> MSE --> TOTAL - VT --> MSE - DG --> TOTAL + + INTERP --> FWD["v_pred = model(prot, lig_t, t)"] + FWD --> MSE["L_flow = MSE(v_pred, v_true)"] + VT["v_true = x1 - x0"] --> MSE + + MSE --> TOTAL["L = L_flow + 0.1 * L_dg"] + DG["L_dg = dist geometry
bond constraints, time-weighted"] --> TOTAL ``` ### 3.2 Training Configuration @@ -355,8 +352,8 @@ src/utils/ model_builder.py Config -> model construction configs/ - train_rectified_flow_full.yaml # v4 config (trained model) - train_joint.yaml # Joint architecture template + train.yaml # Cartesian training config + train_torsion.yaml # SE(3) + Torsion training config train.py Training loop inference.py Inference script diff --git a/inference_results/full_validation_results/hist_improvement.png b/reports/cartesian/assets/hist_improvement.png similarity index 100% rename from inference_results/full_validation_results/hist_improvement.png rename to reports/cartesian/assets/hist_improvement.png diff --git a/inference_results/full_validation_results/hist_rmsd_distribution.png b/reports/cartesian/assets/hist_rmsd_distribution.png similarity index 100% rename from inference_results/full_validation_results/hist_rmsd_distribution.png rename to reports/cartesian/assets/hist_rmsd_distribution.png diff --git a/reports/assets/refinement_example_1d1p.gif b/reports/cartesian/assets/refinement_example_1d1p.gif similarity index 100% rename from reports/assets/refinement_example_1d1p.gif rename to reports/cartesian/assets/refinement_example_1d1p.gif diff --git a/inference_results/full_validation_results/scatter_atoms_vs_improvement.png b/reports/cartesian/assets/scatter_atoms_vs_improvement.png similarity index 100% rename from inference_results/full_validation_results/scatter_atoms_vs_improvement.png rename to reports/cartesian/assets/scatter_atoms_vs_improvement.png diff --git a/inference_results/full_validation_results/scatter_init_vs_final_pdb.png b/reports/cartesian/assets/scatter_init_vs_final_pdb.png similarity index 100% rename from inference_results/full_validation_results/scatter_init_vs_final_pdb.png rename to reports/cartesian/assets/scatter_init_vs_final_pdb.png diff --git a/inference_results/full_validation_results/scatter_init_vs_final_pose.png 
b/reports/cartesian/assets/scatter_init_vs_final_pose.png similarity index 100% rename from inference_results/full_validation_results/scatter_init_vs_final_pose.png rename to reports/cartesian/assets/scatter_init_vs_final_pose.png diff --git a/inference_results/full_validation_results/scatter_init_vs_improvement.png b/reports/cartesian/assets/scatter_init_vs_improvement.png similarity index 100% rename from inference_results/full_validation_results/scatter_init_vs_improvement.png rename to reports/cartesian/assets/scatter_init_vs_improvement.png diff --git a/reports/cartesian/results.md b/reports/cartesian/results.md new file mode 100644 index 0000000..94c818d --- /dev/null +++ b/reports/cartesian/results.md @@ -0,0 +1,89 @@ +# Cartesian v4 Baseline Results + +> **Model**: `rectified-flow-full-v4` (joint graph, 8-layer GatingEquivariantLayer, ~13M params) +> +> **Output**: Per-atom Cartesian velocity [N, 3] +> +> **Evaluation**: 200 PDBs, 11,543 poses, 20-step Euler ODE, EMA applied + +--- + +## Summary Metrics + +| Metric | Before Refinement | After Refinement | Change | +|--------|-------------------|------------------|--------| +| Mean RMSD | 3.20 A | 2.64 A | -0.56 A | +| Median RMSD | 3.00 A | 2.22 A | -0.78 A | +| Success rate (<2A) | 30.4% | 44.6% | +14.2%p | +| Success rate (<1A) | 8.7% | 13.5% | +4.8%p | +| Success rate (<0.5A) | 0.7% | 1.3% | +0.6%p | +| Improved poses | - | 75.2% | - | + +--- + +## Visualizations + +### RMSD Distribution: Before vs After + +![RMSD Distribution](assets/hist_rmsd_distribution.png) + +Refinement 후 분포가 전체적으로 왼쪽(낮은 RMSD)으로 이동. Mean 3.20A -> 2.64A. + +### Per-Pose: Initial vs Final RMSD + +![Per-Pose Scatter](assets/scatter_init_vs_final_pose.png) + +대각선 아래 = 개선된 pose. **75.2%의 pose가 개선됨.** + +### Per-PDB: Average Initial vs Final RMSD + +![Per-PDB Scatter](assets/scatter_init_vs_final_pdb.png) + +PDB 단위로 평균하면 **200개 중 178개 (89.0%)가 개선됨.** 대부분의 target에서 일관된 개선. 
+ +### RMSD Improvement Distribution + +![Improvement Distribution](assets/hist_improvement.png) + +Mean improvement: 0.56A, Median: 0.25A. 양의 방향(개선)으로 skewed. + +### Initial RMSD vs Improvement + +![Init vs Improvement](assets/scatter_init_vs_improvement.png) + +Initial RMSD가 클수록 improvement 폭도 큼. 단, 매우 큰 perturbation (>8A)에서는 효과 감소. + +### Ligand Size vs Improvement + +![Ligand Size](assets/scatter_atoms_vs_improvement.png) + +원자 수가 적은 ligand에서 개선 폭이 크고 분산도 큼. 큰 ligand는 상대적으로 안정적이나 개선 폭이 작음. + +### Refinement Trajectory Example (PDB: 1d1p) + +![Refinement Example](assets/refinement_example_1d1p.gif) + +4개 시점에서의 refinement 결과. Green = crystal, Red = current, Purple circle = initial docked pose. + +--- + +## Training Configuration + +| Parameter | Value | +|-----------|-------| +| Architecture | Joint graph (8x GatingEquivariantLayer) | +| Hidden irreps | `192x0e + 48x1o + 48x1e` (480d) | +| Edge cutoff (PL cross) | 6.0 A, max 16 neighbors | +| Optimizer | Muon (lr=0.005) + AdamW (lr=3e-4) | +| Schedule | Linear warmup (5%) + Plateau (80%) + Cosine decay (15%) | +| Loss | Velocity MSE + Distance geometry loss (weight=0.1) | +| EMA | decay=0.999 | +| Batch size | 32 | +| Epochs | 500 | +| Dropout | 0.1 | + +## Data + +- **Checkpoint**: `save/rectified-flow-full-v4/checkpoints/latest.pt` +- **Config**: `output/cartesian_v4_baseline/train_config.yaml` +- **Inference results**: `output/cartesian_v4_baseline/` diff --git a/reports/progress.md b/reports/progress.md index 83e2061..039d8c0 100644 --- a/reports/progress.md +++ b/reports/progress.md @@ -6,168 +6,74 @@ --- -## Overview +## Approach Comparison -FlowFix는 docking 결과로 얻어진 protein-ligand binding pose를 crystal structure에 가깝게 refinement하는 모델입니다. -SE(3)-equivariant message passing network 위에서 flow matching으로 velocity field를 학습하여, perturbed pose -> crystal pose로의 ODE trajectory를 생성합니다. 
- -### Key Design Choices - -| Component | Choice | Rationale | -|-----------|--------|-----------| -| Representation | Joint protein-ligand graph | Cross-edge로 protein context 직접 전달 | -| Equivariance | cuEquivariance tensor product | SE(3) symmetry 보존, GPU-accelerated | -| Interaction | Direct message passing (no attention) | 단순하고 효율적인 protein-ligand interaction | -| Generative model | Flow matching (rectified flow) | Stable training, fast sampling | -| Protein embedding | ESMC 600M + ESM3 (weighted, gated) | Pre-trained sequence representation | -| Optimizer | Muon + AdamW hybrid | 2D weight에 Muon, 나머지 AdamW | +| | Cartesian (v4) | SE(3) + Torsion | +|--|----------------|-----------------| +| **Output** | Per-atom velocity [N, 3] | Trans [3] + Rot [3] + Torsion [M] | +| **Dimension** | 3N (~132D) | 3+3+M (~14D) | +| **Status** | Trained (500 epochs) | Implemented, not yet trained | +| **Mean RMSD** | 3.20A -> 2.64A | - | +| **Success <2A** | 30.4% -> 44.6% | - | +| **Improved poses** | 75.2% | - | --- -## Current Results (v4 - Baseline) +## Cartesian (v4 Baseline) -**Model**: `rectified-flow-full-v4` (joint graph, 6-layer GatingEquivariantLayer) -**Evaluation**: 200 PDBs, 11,543 poses, 20-step Euler ODE, EMA applied +- **Architecture**: Joint graph, 8x GatingEquivariantLayer, ~13M params +- **Results**: [cartesian/results.md](cartesian/results.md) +- **Architecture detail**: [cartesian/architecture.md](cartesian/architecture.md) +- **Inference data**: `output/cartesian_v4_baseline/` -### Summary Metrics +### Key Results -| Metric | Before Refinement | After Refinement | Change | -|--------|-------------------|------------------|--------| +| Metric | Before | After | Change | +|--------|--------|-------|--------| | Mean RMSD | 3.20 A | 2.64 A | -0.56 A | | Median RMSD | 3.00 A | 2.22 A | -0.78 A | | Success rate (<2A) | 30.4% | 44.6% | +14.2%p | -| Success rate (<1A) | 8.7% | 13.5% | +4.8%p | -| Success rate (<0.5A) | 0.7% | 1.3% | +0.6%p | -| Improved poses | - | 
75.2% | - | - -### RMSD Distribution: Before vs After - -![RMSD Distribution](assets/hist_rmsd_distribution.png) - -Refinement 후 분포가 전체적으로 왼쪽(낮은 RMSD)으로 이동. Mean 3.20A -> 2.64A. - -### Per-Pose: Initial vs Final RMSD - -![Per-Pose Scatter](assets/scatter_init_vs_final_pose.png) - -대각선 아래 = 개선된 pose. **75.2%의 pose가 개선됨.** - -### Per-PDB: Average Initial vs Final RMSD - -![Per-PDB Scatter](assets/scatter_init_vs_final_pdb.png) - -PDB 단위로 평균하면 **200개 중 178개 (89.0%)가 개선됨.** 대부분의 target에서 일관된 개선. - -### RMSD Improvement Distribution -![Improvement Distribution](assets/hist_improvement.png) - -Mean improvement: 0.56A, Median: 0.25A. 양의 방향(개선)으로 skewed. - -### Initial RMSD vs Improvement - -![Init vs Improvement](assets/scatter_init_vs_improvement.png) - -Initial RMSD가 클수록 improvement 폭도 큼. 단, 매우 큰 perturbation (>8A)에서는 효과 감소. - -### Ligand Size vs Improvement - -![Ligand Size](assets/scatter_atoms_vs_improvement.png) - -원자 수가 적은 ligand에서 개선 폭이 크고 분산도 큼. 큰 ligand는 상대적으로 안정적이나 개선 폭이 작음. - -### Refinement Trajectory Example (PDB: 1d1p) +--- -![Refinement Example](assets/refinement_example_1d1p.gif) +## SE(3) + Torsion -4개 시점에서의 refinement 결과. Green = crystal, Red = current, Purple circle = initial docked pose. -Velocity field (orange arrows)를 따라 crystal structure 방향으로 이동. +- **Architecture**: Shared backbone + decomposed output heads (trans/rot/torsion) +- **Results**: [torsion/results.md](torsion/results.md) +- **Architecture detail**: [torsion/architecture.md](torsion/architecture.md) ---- +### Expected Benefits -## Architecture - -> 상세 아키텍처 문서: [architecture.md](architecture.md) - -```mermaid -flowchart LR - subgraph prep["Preprocessing"] - P["Protein"] --> ESM["ESM Integration"] - T["time t"] --> TIME["Time Embedding"] - end - - subgraph joint["Joint Graph"] - ESM --> PPROJ["Protein Proj"] - L["Ligand"] --> LPROJ["Ligand Proj"] - PPROJ --> MERGE["Merge + 4 Edge Types"] - LPROJ --> MERGE - end - - subgraph mp["Message Passing"] - MERGE --> LAYERS["8x GatingEquivLayer
+ Time AdaLN"] - TIME --> LAYERS - end - - subgraph out["Output"] - LAYERS --> EXT["Extract Ligand"] - EXT --> VEL["EquivariantMLP
-> v [N_l, 3]"] - end - - style prep fill:#e8f5e9,stroke:#2E7D32 - style joint fill:#fff3e0,stroke:#E65100 - style mp fill:#fce4ec,stroke:#C62828 - style out fill:#f3e5f5,stroke:#6A1B9A -``` - -**Joint graph architecture:** -- Protein + ligand를 하나의 그래프로 합쳐서 **direct message passing** (cross-attention 없음) -- 4 edge types: PP (pre-computed), LL bonds (pre-computed), LL intra (dynamic), PL cross (dynamic, 6.0A cutoff) -- 8x GatingEquivariantLayer with time conditioning via AdaLN -- Velocity output: ligand slice 추출 후 EquivariantMLP -> 3D velocity - -### Training Setup - -| Parameter | Value | -|-----------|-------| -| Architecture | Joint graph (8x GatingEquivariantLayer) | -| Hidden irreps | `192x0e + 48x1o + 48x1e` (480d) | -| Edge cutoff (PL cross) | 6.0 A, max 16 neighbors | -| Optimizer | Muon (lr=0.005) + AdamW (lr=3e-4) | -| Schedule | Linear warmup (5%) + Plateau (80%) + Cosine decay (15%) | -| Loss | Velocity MSE + Distance geometry loss (weight=0.1) | -| EMA | decay=0.999 | -| Batch size | 32 | -| Epochs | 500 | -| Dropout | 0.1 | +- 3-10x dimension reduction (물리적 자유도만 학습) +- Bond length/angle 자동 보존 +- Interpretable decomposition (translation vs rotation vs torsion) --- ## TODO / Next Steps -- [ ] Success rate <2A 목표: 60%+ (현재 44.6%) -- [ ] Self-conditioning 효과 ablation -- [ ] Torsion space decomposition 적용 (utilities exist in ligand_feat.py, not yet integrated) -- [ ] Multi-step refinement (iterative application) -- [ ] Larger dataset / cross-dataset generalization -- [ ] Inference speed optimization (fewer ODE steps) +- [ ] Torsion 모델 학습 및 Cartesian과 비교 +- [ ] Success rate <2A 목표: 60%+ +- [ ] Self-conditioning ablation +- [ ] Multi-step refinement +- [ ] Inference speed optimization --- ## Changelog +### 2026-03-06 - SE(3) + Torsion Implementation +- Translation [3] + Rotation [3] + Torsion [M] decomposition 구현 +- 별도 model/dataset/loss/sampling/trainer 파일 분리 +- Codebase cleanup (dead files, legacy configs 정리) + ### 2026-02-18 - v4 Baseline Results - Full 
validation on 200 PDBs (11,543 poses) - 20-step Euler ODE with EMA model - Mean RMSD: 3.20A -> 2.64A, Success rate <2A: 30.4% -> 44.6% -### 2026-02 - Joint Graph Architecture (v4, current) -- Joint protein-ligand graph with 4 edge types (PP, LL, LL intra, PL cross) -- 8x GatingEquivariantLayer with time AdaLN conditioning -- cuEquivariance tensor product for SE(3) equivariance +### 2026-02 - Joint Graph Architecture (v4) +- Joint protein-ligand graph with 4 edge types +- 8x GatingEquivariantLayer with time AdaLN +- cuEquivariance tensor product - Muon + AdamW hybrid optimizer -- EMA (decay=0.999) for inference - -### 2024-11 - SE(3) + Torsion Decomposition (not used in training) -- Translation [3D] + Rotation [3D] + Torsion [M] decomposition utilities implemented -- Not integrated into model/training - current model uses Cartesian velocity -- Chain-wise ESM embedding support added diff --git a/reports/torsion/architecture.md b/reports/torsion/architecture.md new file mode 100644 index 0000000..aaeb32d --- /dev/null +++ b/reports/torsion/architecture.md @@ -0,0 +1,115 @@ +# SE(3) + Torsion Decomposition Architecture + +> **Model**: `ProteinLigandFlowMatchingTorsion` +> +> **Output**: Translation [3] + Rotation [3] + Torsion [M] (instead of per-atom velocity [N, 3]) + +--- + +## 1. Overview + +Molecular pose를 SE(3) + Torsion 공간으로 분해하여, Cartesian 3N차원 대신 (3 + 3 + M)차원에서 velocity field를 학습합니다. + +- **Translation [3D]**: 분자 중심의 이동 +- **Rotation [3D, SO(3)]**: 분자 전체 회전 (axis-angle) +- **Torsion [M]**: M개 rotatable bond의 회전각 + +### Dimension Reduction + +일반적인 drug molecule (N=44 atoms, M=8 rotatable bonds): +- Cartesian: 44 x 3 = **132D** +- Torsion: 3 + 3 + 8 = **14D** (9.4x reduction) + +--- + +## 2. Architecture + +### 2.1 Backbone (Shared with Cartesian) + +Encoder, interaction network, velocity blocks는 base `ProteinLigandFlowMatching`과 동일. 
+ +``` +ProteinLigandFlowMatchingTorsion (inherits ProteinLigandFlowMatching) +├── [shared] protein_encoder, ligand_encoder +├── [shared] cross_attention, velocity_blocks +├── [new] translation_head: EquivariantMLP -> 1x1o [3D] +├── [new] rotation_head: EquivariantMLP -> 1x1o [3D] +└── [new] torsion_head: MLP(2*scalar_dim -> 1) per bond +``` + +### 2.2 Output Heads + +**Translation & Rotation**: Global pooling (scatter_mean) -> EquivariantMLP -> 1x1o + +``` +h_scalar[mol_i] -> scatter_mean -> EquivariantMLP -> translation [3] +h_scalar[mol_i] -> scatter_mean -> EquivariantMLP -> rotation [3] +``` + +**Torsion**: Edge-level prediction (DiffDock approach) + +``` +For each rotatable bond (src, dst): + cat(h_scalar[src], h_scalar[dst]) -> MLP -> scalar angle +``` + +### 2.3 Application Order + +Torsion -> Translation -> Rotation (DiffDock convention) + +``` +1. Apply torsion angles (Rodrigues rotation per bond) +2. Translate center of mass +3. Rotate around center of mass (axis-angle via Rodrigues) +``` + +--- + +## 3. Loss Function + +``` +L = w_trans * L_trans + w_rot * L_rot + w_tor * L_tor + w_coord * L_coord + +L_trans = MSE(pred_trans, target_trans) +L_rot = MSE(pred_rot, target_rot) +L_tor = circular_MSE(pred_tor, target_tor) # atan2(sin, cos) wrapping +L_coord = MSE(reconstruct(x0, pred), x1) # end-to-end reconstruction +``` + +Default weights: `w_trans=1.0, w_rot=1.0, w_tor=1.0, w_coord=0.5` + +--- + +## 4. Training Configuration + +| Parameter | Value | +|-----------|-------| +| Architecture | Torsion (shared backbone + decomposed heads) | +| Hidden scalar/vector | 128 / 32 | +| Optimizer | Muon (lr=0.005) + AdamW (lr=3e-4) | +| Schedule | Warmup 5% + Plateau 80% + Cosine 15% | +| Loss | Trans MSE + Rot MSE + Circular Torsion + Coord Recon | +| EMA | decay=0.999 | +| Batch size | 32 | +| Epochs | 500 | +| ODE sampling | 20-step Euler | + +--- + +## 5. 
File Map + +``` +src/models/flowmatching_torsion.py # ProteinLigandFlowMatchingTorsion +src/data/dataset_torsion.py # FlowFixTorsionDataset, collate_torsion_batch +src/utils/losses_torsion.py # compute_se3_torsion_loss, reconstruct_coords +src/utils/sampling_torsion.py # sample_trajectory_torsion +train_torsion.py # FlowFixTorsionTrainer +configs/train_torsion.yaml # Training config +``` + +--- + +## 6. References + +- DiffDock (Corso et al., 2023): Product space diffusion on R3 x SO(3) x T^M +- Torsional Diffusion (Jing et al., 2022): Diffusion on torsion angles diff --git a/reports/torsion/results.md b/reports/torsion/results.md new file mode 100644 index 0000000..33f050e --- /dev/null +++ b/reports/torsion/results.md @@ -0,0 +1,33 @@ +# SE(3) + Torsion Results + +> **Status**: Not yet trained +> +> **Model**: `ProteinLigandFlowMatchingTorsion` +> +> **Output**: Translation [3] + Rotation [3] + Torsion [M] + +--- + +## Summary Metrics + +| Metric | Before Refinement | After Refinement | Change | +|--------|-------------------|------------------|--------| +| Mean RMSD | - | - | - | +| Median RMSD | - | - | - | +| Success rate (<2A) | - | - | - | +| Success rate (<1A) | - | - | - | +| Improved poses | - | - | - | + +> 학습 완료 후 업데이트 예정 + +--- + +## Visualizations + +> 학습 완료 후 `assets/`에 추가 예정 + +--- + +## Training Log + +> WandB 링크 및 학습 커브 추가 예정 diff --git a/scripts/README.md b/scripts/README.md index 8a1c2fb..45b2876 100755 --- a/scripts/README.md +++ b/scripts/README.md @@ -15,6 +15,8 @@ scripts/ - `visualize_loss.py`: 학습 로그 기반 loss 시각화 - `visualize_trajectory.py`: 샘플링 trajectory 시각화 +- `infer_full_validation.py`: 전체 validation set 추론 +- `infer_small_train_valid.py`: 소규모 train/valid 추론 ## `scripts/data/` @@ -24,14 +26,11 @@ scripts/ - `generate_test_data.py`: 테스트 데이터 생성 - `inspect_data.py`: 데이터 점검 - `verify_test_data.py`: 테스트 데이터 검증 -- `run_preprocess_pdbbind.sh`: 전처리 실행용 쉘 래퍼 ## `scripts/slurm/` 환경(경로, 파티션, GPU 수)은 클러스터마다 다르므로, 본인 환경에 맞게 `PYTHON`, `PROJECT_DIR`, 
`#SBATCH ...`를 조정해서 사용하세요. -- `run_train_joint_test.sh`: 단일 GPU 빠른 테스트 -- `run_train_joint.sh`: 멀티 GPU(DDP) 학습 +- `run_train_full.sh`: 멀티 GPU(DDP) 학습 - `run_inference.sh`: 체크포인트 추론/평가 -- `run_visualize_trajectory.sh`: trajectory 시각화 - +- `run_full_val_inference.sh`: 전체 validation 추론 diff --git a/scripts/analysis/infer_full_validation.py b/scripts/analysis/infer_full_validation.py index 6e419c2..0311692 100755 --- a/scripts/analysis/infer_full_validation.py +++ b/scripts/analysis/infer_full_validation.py @@ -310,7 +310,7 @@ def save_partial(signum, frame): def main(): p = argparse.ArgumentParser(description="Full validation inference") - p.add_argument("--config", default="configs/train_joint.yaml") + p.add_argument("--config", default="configs/train.yaml") p.add_argument("--checkpoint", default="save/rectified-flow-full-v4/checkpoints/latest.pt") p.add_argument("--output", default="inference_results/full_validation_results.json") p.add_argument("--no_ema", action="store_true") diff --git a/scripts/debug_ckpt.py b/scripts/debug_ckpt.py deleted file mode 100755 index 97b37b4..0000000 --- a/scripts/debug_ckpt.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -import torch -import yaml -import argparse -from pathlib import Path -from torch_geometric.loader import DataLoader -from src.data.dataset import FlowFixDataset, collate_flowfix_batch -from src.models.flowmatching import ProteinLigandFlowMatchingJoint -from src.utils.relaxation import RelaxationEngine - -def test_checkpoint(checkpoint_path, config_path, split="val", pdb_id=None, do_relax=False): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"\n" + "="*60) - print(f"Testing Checkpoint: {checkpoint_path}") - print(f"Config: {config_path}") - print("="*60) - - # Load config - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - - # Load dataset - dataset = FlowFixDataset( - data_dir=config['data']['data_dir'], - split_file=config['data'].get('split_file'), - split=split, - 
max_samples=None, - seed=42, - fix_pose=config['data'].get('fix_pose', False), - fix_pose_high_rmsd=config['data'].get('fix_pose_high_rmsd', False) - ) - - # Selection - if pdb_id: - try: - target_idx = dataset.pdb_ids.index(pdb_id) - sample = dataset[target_idx] - except ValueError: - print(f"PDB {pdb_id} not in val split, using first entry") - sample = dataset[0] - else: - sample = dataset[0] - - # Model - import inspect - sig = inspect.signature(ProteinLigandFlowMatchingJoint.__init__) - valid_params = sig.parameters.keys() - model_config = {k: v for k, v in config['model'].items() if k in valid_params} - - model = ProteinLigandFlowMatchingJoint(**model_config).to(device) - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - model.load_state_dict(checkpoint['model_state_dict'], strict=False) - - # TEST: Keep in train mode to see if BatchNorm is the culprit - model.eval() - model.eval() - print("WARNING: Running in MODEL.TRAIN() mode to avoid BatchNorm eval mismatch") - - # Data - batch = collate_flowfix_batch([sample]) - ligand_batch = batch['ligand_graph'].to(device) - protein_batch = batch['protein_graph'].to(device) - ligand_coords_x0 = batch['ligand_coords_x0'].to(device) - ligand_coords_x1 = batch['ligand_coords_x1'].to(device) - - initial_rmsd = torch.sqrt(torch.mean((ligand_coords_x0 - ligand_coords_x1) ** 2)).item() - print(f"\nPDB: {batch['pdb_ids'][0]}") - print(f"Initial RMSD: {initial_rmsd:.4f} Å") - - # ODE Integration - num_steps = 20 - current_coords = ligand_coords_x0.clone() - timesteps = torch.linspace(0.0, 1.0, num_steps + 1) - - print(f"\n{'Step':>4} | {'t':>5} | {'RMSD':>8} | {'|v|':>8} | {'|v_tgt|':>8} | {'Sim':>8}") - print("-" * 60) - - for step in range(num_steps): - t_current = timesteps[step] - t_next = timesteps[step+1] - dt = t_next - t_current - - t = torch.ones(1, device=device) * t_current - - with torch.no_grad(): - ligand_batch.pos = current_coords - velocity = model(protein_batch, ligand_batch, 
t) - - # target v in flow matching sense: (x1 - xt) / (1-t) - eps = 1e-5 - v_target = (ligand_coords_x1 - current_coords) / (1.0 - t_current + eps) - - rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item() - v_norm = torch.norm(velocity, dim=-1).mean().item() - v_target_norm = torch.norm(v_target, dim=-1).mean().item() - - v_flat = velocity.view(-1) - target_flat = v_target.view(-1) - sim = torch.nn.functional.cosine_similarity(v_flat, target_flat, dim=0).item() - - if step % 5 == 0 or step == num_steps - 1: - print(f"{step:4d} | {t_current:5.2f} | {rmsd:8.4f} | {v_norm:8.4f} | {v_target_norm:8.4f} | {sim:8.4f}") - - current_coords = current_coords + velocity * dt - - final_rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item() - print("-" * 60) - print(f"Final RMSD: {final_rmsd:.4f} Å") - print(f"Refinement: {initial_rmsd - final_rmsd:.4f} Å") - - if do_relax: - print("\n" + "="*60) - print("Running Force Field Relaxation...") - print("="*60) - - relax_engine = RelaxationEngine( - clash_weight=1.0, - dg_weight=1.0, - restraint_weight=20.0, - lr=1.0, - max_steps=100 - ) - - relaxed_coords, metrics = relax_engine.relax( - ligand_coords=current_coords, - protein_batch=protein_batch, - distance_bounds=batch.get('distance_bounds'), - device=device - ) - - relaxed_rmsd = torch.sqrt(torch.mean((relaxed_coords - ligand_coords_x1) ** 2)).item() - print(f"Relaxed RMSD: {relaxed_rmsd:.4f} Å") - print(f"Improvement: {final_rmsd - relaxed_rmsd:.4f} Å") - print(f"Metrics: {metrics}") - print("-" * 60) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--ckpt', type=str, required=True) - parser.add_argument('--config', type=str, required=True) - parser.add_argument('--split', type=str, default='val') - parser.add_argument('--pdb', type=str, default=None) - parser.add_argument('--relax', action='store_true', help='Run force field relaxation') - args = parser.parse_args() - - 
test_checkpoint(args.ckpt, args.config, args.split, args.pdb, args.relax) diff --git a/scripts/debug_nan.py b/scripts/debug_nan.py deleted file mode 100755 index 7cfbe96..0000000 --- a/scripts/debug_nan.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python -""" -Debug script to identify NaN source in validation. -Tests model forward pass step by step. -""" - -import torch -import yaml -import sys -sys.path.insert(0, '/home/jaemin/project/protein-ligand/pose-refine') - -from src.utils.model_builder import build_model -from src.data.dataset import FlowFixDataset, collate_flowfix_batch -from torch.utils.data import DataLoader - -def check_tensor(name, t): - """Check tensor for NaN/Inf and print stats.""" - if t is None: - print(f" {name}: None") - return - nan_count = torch.isnan(t).sum().item() - inf_count = torch.isinf(t).sum().item() - print(f" {name}: shape={tuple(t.shape)}, nan={nan_count}, inf={inf_count}, " - f"min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}") - -def main(): - # Load config - config_path = '/home/jaemin/project/protein-ligand/pose-refine/save/overfit-test-32/config.yaml' - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - - device = torch.device('cuda:0') - - # Load dataset (validation mode - uses x0 positions, not x_t) - val_dataset = FlowFixDataset( - data_dir=config['data']['data_dir'], - split_file=config['data']['split_file'], - split='val', - fix_pose=config['data'].get('fix_pose', True), - fix_pose_high_rmsd=config['data'].get('fix_pose_high_rmsd', False), - position_noise_scale=0.0, # No noise for validation - loading_mode='lazy', - max_samples=1, - ) - - val_loader = DataLoader( - val_dataset, - batch_size=1, - shuffle=False, - collate_fn=collate_flowfix_batch, - ) - - # Build model - model = build_model(config['model'], device) - - # Load checkpoint - checkpoint_path = '/home/jaemin/project/protein-ligand/pose-refine/save/overfit-test-32/checkpoints/latest.pt' - checkpoint = 
torch.load(checkpoint_path, map_location=device, weights_only=False) - model.load_state_dict(checkpoint['model_state_dict']) - - print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}") - - # Test both eval modes: standard eval and eval with BatchNorm in train mode - for bn_mode in ['eval', 'train']: - print(f"\n{'='*60}") - print(f"TESTING WITH BATCHNORM IN {bn_mode.upper()} MODE") - print(f"{'='*60}") - - model.eval() - - # Optionally keep BatchNorm in train mode - if bn_mode == 'train': - for module in model.modules(): - cls_name = type(module).__name__ - if 'BatchNorm' in cls_name: - module.train() - - test_model(model, val_loader, device) - - -def test_model(model, val_loader, device): - """Test model forward pass and ODE integration with detailed metrics.""" - - # Get one batch - batch = next(iter(val_loader)) - - # Move to device - ligand_batch = batch['ligand_graph'].to(device) - protein_batch = batch['protein_graph'].to(device) - ligand_coords_x0 = batch['ligand_coords_x0'].to(device) - ligand_coords_x1 = batch['ligand_coords_x1'].to(device) - - print(f"\n=== Input Data Check ===") - print(f"PDB: {batch['pdb_ids'][0]}") - check_tensor("ligand_coords_x0", ligand_coords_x0) - check_tensor("ligand_coords_x1", ligand_coords_x1) - - # Initial RMSD - initial_rmsd = torch.sqrt(torch.mean((ligand_coords_x0 - ligand_coords_x1) ** 2)).item() - print(f"Initial RMSD: {initial_rmsd:.4f} Å") - - # Simulate ODE integration with detailed debugging - print(f"\n=== Detailed ODE Integration Debug (Euler, 20 steps) ===") - print(f"{'Step':>4} | {'t':>5} | {'RMSD':>7} | {'|v|':>7} | {'|v_tgt|':>7} | {'Sim':>7}") - print("-" * 60) - - num_steps = 20 - current_coords = ligand_coords_x0.clone() - - for step in range(num_steps): - t_current = step / num_steps - t_next = (step + 1) / num_steps - dt = t_next - t_current - - t = torch.tensor([t_current], device=device) - ligand_batch.pos = current_coords.clone() - - # Self-conditioning (if enabled in model, usually 
x1_self_cond) - # Note: In actual validation, we might need to handle self_conditioning properly - # but for simple velocity direction check, t=0 is most critical. - - with torch.no_grad(): - # In validation, we might use x_t directly for self_cond - velocity = model(protein_batch, ligand_batch, t) - - # Ideal velocity at this point to reach x1 - # In CFM with linear interpolation, the constant velocity is (x1 - x0) - # However, our target direction at current point is (x1 - xt) - v_target = ligand_coords_x1 - current_coords - - # Metrics - rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item() - v_norm = torch.norm(velocity, dim=-1).mean().item() - v_target_norm = torch.norm(v_target, dim=-1).mean().item() - - # Cosine similarity between predicted velocity and target direction - v_flat = velocity.view(-1) - target_flat = v_target.view(-1) - sim = torch.nn.functional.cosine_similarity(v_flat, target_flat, dim=0).item() - - print(f"{step:4d} | {t_current:5.2f} | {rmsd:7.4f} | {v_norm:7.4f} | {v_target_norm:7.4f} | {sim:7.4f}") - - # Clip velocity (same as in validation) - velocity_norm = torch.norm(velocity, dim=-1, keepdim=True) - max_velocity = 5.0 - scale = torch.clamp(max_velocity / (velocity_norm + 1e-8), max=1.0) - velocity = velocity * scale - - # Update coords - current_coords = current_coords + dt * velocity - - if torch.isnan(current_coords).any(): - print(f"!!! Error: NaN detected at step {step}") - break - - final_rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item() - print("-" * 45) - print(f"Final RMSD: {final_rmsd:.4f} Å") - print(f"Refinement: {initial_rmsd - final_rmsd:+.4f} Å") - -if __name__ == '__main__': - main() diff --git a/scripts/debug_velocity.py b/scripts/debug_velocity.py deleted file mode 100755 index c3daf39..0000000 --- a/scripts/debug_velocity.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python -""" -Debug script to check model velocity predictions. 
-""" -import os -import sys -import torch -import yaml -import numpy as np - -# Add project root to path -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from src.data.dataset import FlowFixDataset, collate_flowfix_batch -from src.utils.model_builder import build_model - - -def main(): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Device: {device}") - - # Load checkpoint from overfit-test-8 - checkpoint_path = 'save/overfit-test-8/checkpoints/latest.pt' - if not os.path.exists(checkpoint_path): - print(f"No checkpoint found at {checkpoint_path}") - return - - print(f"Loading checkpoint: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - - # Load config from checkpoint if available, otherwise from file - if 'config' in checkpoint: - config = checkpoint['config'] - print("Using config from checkpoint") - else: - config_path = 'configs/train_joint.yaml' - with open(config_path) as f: - config = yaml.safe_load(f) - print(f"Using config from {config_path}") - - # Build model - model = build_model(config['model'], device) - - # Load weights - state_dict = checkpoint.get('model_state_dict', checkpoint) - model.load_state_dict(state_dict, strict=False) - model = model.to(device) - model.eval() - - print(f"Model loaded, epoch {checkpoint.get('epoch', 'unknown')}") - - # Create dataset - dataset = FlowFixDataset( - data_dir='train_data', - split_file='train_data/splits_overfit_tiny.json', - split='train', # Use train split to get fixed pose - seed=42, - fix_pose=True, # Use pose 1 (easier) - ) - - # Get one sample - sample = dataset[0] - batch = collate_flowfix_batch([sample]) - - pdb_id = batch['pdb_ids'][0] - protein_batch = batch['protein_graph'].to(device) - ligand_batch = batch['ligand_graph'].to(device) - ligand_coords_x0 = batch['ligand_coords_x0'].to(device) - ligand_coords_x1 = batch['ligand_coords_x1'].to(device) - t = batch['t'].to(device) 
- - print(f"\nPDB: {pdb_id}") - print(f"Num ligand atoms: {ligand_coords_x0.shape[0]}") - print(f"t: {t.item():.4f}") - - # True velocity - true_velocity = ligand_coords_x1 - ligand_coords_x0 - true_velocity_norm = torch.norm(true_velocity, dim=-1).mean() - - print(f"\n=== True Velocity ===") - print(f" Mean norm: {true_velocity_norm.item():.4f} Å") - print(f" First 5 atoms:\n{true_velocity[:5]}") - - # Predicted velocity at t=0 - with torch.no_grad(): - ligand_batch.pos = ligand_coords_x0.clone() - t_zero = torch.zeros(1, device=device) - pred_velocity_t0 = model(protein_batch, ligand_batch, t_zero) - - pred_velocity_t0_norm = torch.norm(pred_velocity_t0, dim=-1).mean() - - print(f"\n=== Predicted Velocity (t=0) ===") - print(f" Mean norm: {pred_velocity_t0_norm.item():.4f} Å") - print(f" First 5 atoms:\n{pred_velocity_t0[:5]}") - - # Cosine similarity - cos_sim = torch.nn.functional.cosine_similarity( - pred_velocity_t0.flatten(), - true_velocity.flatten(), - dim=0 - ) - print(f"\n=== Comparison ===") - print(f" Cosine similarity: {cos_sim.item():.4f}") - print(f" Pred/True norm ratio: {pred_velocity_t0_norm.item() / true_velocity_norm.item():.4f}") - - # Check if velocity is near zero - if pred_velocity_t0_norm.item() < 0.1: - print(f"\n⚠️ WARNING: Predicted velocity is near zero!") - - # Predicted velocity at various t - print(f"\n=== Velocity at different timesteps ===") - for t_val in [0.0, 0.25, 0.5, 0.75, 1.0]: - t_test = torch.tensor([t_val], device=device) - x_t = (1 - t_val) * ligand_coords_x0 + t_val * ligand_coords_x1 - ligand_batch.pos = x_t.clone() - - with torch.no_grad(): - pred_v = model(protein_batch, ligand_batch, t_test) - - pred_norm = torch.norm(pred_v, dim=-1).mean() - cos = torch.nn.functional.cosine_similarity( - pred_v.flatten(), - true_velocity.flatten(), - dim=0 - ) - print(f" t={t_val:.2f}: ||v||={pred_norm.item():.4f}, cos_sim={cos.item():.4f}") - - # Initial and final RMSD - initial_rmsd = torch.sqrt(torch.mean((ligand_coords_x0 - 
ligand_coords_x1)**2)) - print(f"\n=== RMSD ===") - print(f" Initial (docked vs crystal): {initial_rmsd.item():.4f} Å") - - # Simulate one step - dt = 0.05 - new_coords = ligand_coords_x0 + dt * pred_velocity_t0 - new_rmsd = torch.sqrt(torch.mean((new_coords - ligand_coords_x1)**2)) - print(f" After one step (dt={dt}): {new_rmsd.item():.4f} Å") - - # Check for NaN/Inf in velocity - print(f"\n=== NaN/Inf Check ===") - print(f" Velocity has NaN: {torch.isnan(pred_velocity_t0).any().item()}") - print(f" Velocity has Inf: {torch.isinf(pred_velocity_t0).any().item()}") - print(f" Velocity max: {pred_velocity_t0.abs().max().item():.4f}") - - # Full ODE simulation - print(f"\n=== Full ODE Simulation (10 steps) ===") - current = ligand_coords_x0.clone() - for step in range(10): - t_val = step / 10 - t_test = torch.tensor([t_val], device=device) - ligand_batch.pos = current.clone() - with torch.no_grad(): - v = model(protein_batch, ligand_batch, t_test) - dt = 0.1 - current = current + dt * v - rmsd = torch.sqrt(torch.mean((current - ligand_coords_x1)**2)).item() - v_norm = torch.norm(v, dim=-1).mean().item() - v_max = torch.abs(v).max().item() - has_nan = torch.isnan(current).any().item() - print(f" Step {step}: t={t_val:.2f}, RMSD={rmsd:.4f}A, ||v||={v_norm:.4f}, max|v|={v_max:.4f}, NaN={has_nan}") - if has_nan: - print(" NaN detected, stopping") - break - - # Test at t=0 vs t=0.5 to see if model generalizes - print(f"\n=== Model behavior at different timesteps (same position) ===") - ligand_batch.pos = ligand_coords_x0.clone() - for t_val in [0.0, 0.25, 0.5, 0.75, 1.0]: - t_test = torch.tensor([t_val], device=device) - with torch.no_grad(): - v = model(protein_batch, ligand_batch, t_test) - v_norm = torch.norm(v, dim=-1).mean().item() - v_max = torch.abs(v).max().item() - cos_sim = torch.nn.functional.cosine_similarity(v.flatten(), true_velocity.flatten(), dim=0).item() - print(f" t={t_val:.2f}: ||v||={v_norm:.4f}, max|v|={v_max:.4f}, cos_sim={cos_sim:.4f}") - - # Test 
at EXACT training position x_t - print(f"\n=== Test at EXACT training position x_t (should match training) ===") - for t_val in [0.0, 0.25, 0.5, 0.75, 1.0]: - # Compute x_t = (1-t)*x0 + t*x1 - x_t = (1 - t_val) * ligand_coords_x0 + t_val * ligand_coords_x1 - ligand_batch.pos = x_t.clone() - t_test = torch.tensor([t_val], device=device) - with torch.no_grad(): - v = model(protein_batch, ligand_batch, t_test) - v_norm = torch.norm(v, dim=-1).mean().item() - cos_sim = torch.nn.functional.cosine_similarity(v.flatten(), true_velocity.flatten(), dim=0).item() - mse = torch.mean((v - true_velocity)**2).item() - print(f" t={t_val:.2f}: ||v||={v_norm:.4f}, cos_sim={cos_sim:.4f}, MSE={mse:.4f}") - - -if __name__ == '__main__': - main() diff --git a/scripts/slurm/run_debug.sh b/scripts/slurm/run_debug.sh deleted file mode 100755 index 47db03d..0000000 --- a/scripts/slurm/run_debug.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=debug-velocity -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=4 -#SBATCH --mem=32G -#SBATCH --time=00:10:00 -#SBATCH --output=logs/debug_%j.out -#SBATCH --error=logs/debug_%j.out -#SBATCH --exclude=gpu3 - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine -PYTHONPATH=/home/jaemin/project/protein-ligand/pose-refine python scripts/debug_velocity.py diff --git a/scripts/slurm/run_debug_ckpts.sh b/scripts/slurm/run_debug_ckpts.sh deleted file mode 100755 index ae40600..0000000 --- a/scripts/slurm/run_debug_ckpts.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=debug-train -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=32G -#SBATCH --time=00:30:00 -#SBATCH --output=logs/debug_train_%j.out -#SBATCH --error=logs/debug_train_%j.out - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 -export 
PYTHONPATH=$PYTHONPATH:. - -echo "============================================================" -echo "Analyzing Job 203 (32-Sample Fixed) - Training Sample 10gs" -echo "============================================================" -python scripts/debug_ckpt.py --ckpt save/overfit-test-32-fixed/checkpoints/latest.pt --config configs/train_overfit_32.yaml --split train --pdb 10gs - -echo -e "\n\n============================================================" -echo "Analyzing Job 226 (Rectified Flow) - Training Sample 10gs (Epoch 20)" -echo "============================================================" -python scripts/debug_ckpt.py --ckpt save/overfit-fast-sink/checkpoints/epoch_0020.pt --config configs/train_rectified_flow.yaml --split train --pdb 10gs diff --git a/scripts/slurm/run_debug_detailed.sh b/scripts/slurm/run_debug_detailed.sh deleted file mode 100755 index 1447d20..0000000 --- a/scripts/slurm/run_debug_detailed.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=debug-ode -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=32G -#SBATCH --time=01:00:00 -#SBATCH --output=logs/debug_ode_%j.out -#SBATCH --error=logs/debug_ode_%j.out - -# Detailed ODE analysis script -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -python scripts/debug_nan.py diff --git a/scripts/slurm/run_debug_nan.sh b/scripts/slurm/run_debug_nan.sh deleted file mode 100755 index 8eb6313..0000000 --- a/scripts/slurm/run_debug_nan.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=debug-nan -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=4 -#SBATCH --mem=32G -#SBATCH --time=00:30:00 -#SBATCH --output=logs/debug_nan_%j.out -#SBATCH --error=logs/debug_nan_%j.out -#SBATCH --exclude=gpu3 - -# Debug NaN issue in validation -# Tests the BatchNorm fix by running validation on 
checkpoint - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Enable TF32 -export TORCH_ALLOW_TF32_CUBLAS=1 - -# Run debug script -python scripts/debug_nan.py diff --git a/scripts/slurm/run_debug_relax.sh b/scripts/slurm/run_debug_relax.sh deleted file mode 100755 index 193beda..0000000 --- a/scripts/slurm/run_debug_relax.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=debug-relax -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=4 -#SBATCH --mem=32G -#SBATCH --time=00:30:00 -#SBATCH --output=logs/debug_relax_%j.out -#SBATCH --error=logs/debug_relax_%j.out - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine -PYTHONPATH=/home/jaemin/project/protein-ligand/pose-refine python scripts/debug_ckpt.py --ckpt save/overfit-fast-sink/checkpoints/latest.pt --config configs/train_rectified_flow.yaml --split train --pdb 10gs --relax diff --git a/scripts/slurm/run_fast_overfit.sh b/scripts/slurm/run_fast_overfit.sh deleted file mode 100755 index 546cac0..0000000 --- a/scripts/slurm/run_fast_overfit.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=overfit-fast-sink -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=64G -#SBATCH --time=04:00:00 -#SBATCH --output=logs/overfit_fast_sink_%j.out -#SBATCH --error=logs/overfit_fast_sink_%j.out -#SBATCH --exclude=gpu3 - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Enable TF32 -export TORCH_ALLOW_TF32_CUBLAS=1 - -# Run training -python train.py --config configs/train_fast_overfit.yaml diff --git a/scripts/slurm/run_full_val_inference.sh b/scripts/slurm/run_full_val_inference.sh index 66789d2..deb6e37 100755 --- 
a/scripts/slurm/run_full_val_inference.sh +++ b/scripts/slurm/run_full_val_inference.sh @@ -42,7 +42,7 @@ fi echo "Starting full validation inference (200 val PDBs, all poses)..." $PYTHON -u scripts/analysis/infer_full_validation.py \ - --config configs/train_joint.yaml \ + --config configs/train.yaml \ --checkpoint "$CHECKPOINT" \ --output "$OUTPUT" diff --git a/scripts/slurm/run_overfit_32.sh b/scripts/slurm/run_overfit_32.sh deleted file mode 100755 index 78b2b57..0000000 --- a/scripts/slurm/run_overfit_32.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=overfit-32 -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=64G -#SBATCH --time=04:00:00 -#SBATCH --output=logs/overfit_32_%j.out -#SBATCH --error=logs/overfit_32_%j.out -#SBATCH --exclude=gpu3 - -# Overfit test with 32 samples -# More data for stable BatchNorm stats, val=subset of train - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Enable TF32 -export TORCH_ALLOW_TF32_CUBLAS=1 - -# Run training -python train.py --config configs/train_overfit_32.yaml diff --git a/scripts/slurm/run_overfit_32_fixed.sh b/scripts/slurm/run_overfit_32_fixed.sh deleted file mode 100755 index 7c4c9ad..0000000 --- a/scripts/slurm/run_overfit_32_fixed.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=overfit-32-fix -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=64G -#SBATCH --time=04:00:00 -#SBATCH --output=logs/overfit_32_fix_%j.out -#SBATCH --error=logs/overfit_32_fix_%j.out -#SBATCH --exclude=gpu3 - -# Overfit test with 32 samples -# More data for stable BatchNorm stats, val=subset of train - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Enable TF32 -export TORCH_ALLOW_TF32_CUBLAS=1 - -# 
Run training -python train.py --config configs/train_overfit_32.yaml diff --git a/scripts/slurm/run_overfit_high_rmsd.sh b/scripts/slurm/run_overfit_high_rmsd.sh deleted file mode 100755 index a1f0ef0..0000000 --- a/scripts/slurm/run_overfit_high_rmsd.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=overfit-high-rmsd -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=64G -#SBATCH --time=04:00:00 -#SBATCH --output=logs/overfit_high_rmsd_%j.out -#SBATCH --error=logs/overfit_high_rmsd_%j.out -#SBATCH --exclude=gpu3 - -# Overfit test with high RMSD poses (32 samples) -# Testing if the model can handle large refinements - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Enable TF32 -export TORCH_ALLOW_TF32_CUBLAS=1 - -# Run training -python train.py --config configs/train_overfit_high_rmsd.yaml diff --git a/scripts/slurm/run_overfit_quick.sh b/scripts/slurm/run_overfit_quick.sh deleted file mode 100755 index 9015f4a..0000000 --- a/scripts/slurm/run_overfit_quick.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=overfit-quick -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=64G -#SBATCH --time=01:00:00 -#SBATCH --output=logs/overfit_quick_%j.out -#SBATCH --error=logs/overfit_quick_%j.out -#SBATCH --exclude=gpu3 - -# Quick overfit test to verify BatchNorm fix -# Runs shorter training with more frequent validation to test NaN fix - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Enable TF32 -export TORCH_ALLOW_TF32_CUBLAS=1 - -# Run with quick validation (validate every 50 epochs instead of 100) -python train.py \ - --config configs/train_overfit_test.yaml \ - --name overfit-quick-batchnorm-fix \ - --training.num_epochs 200 \ - 
--training.validation.frequency 50 diff --git a/scripts/slurm/run_overfit_test.sh b/scripts/slurm/run_overfit_test.sh deleted file mode 100755 index ddf78f9..0000000 --- a/scripts/slurm/run_overfit_test.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=flowfix-overfit -#SBATCH --partition=6000ada -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=64G -#SBATCH --time=04:00:00 -#SBATCH --output=logs/slurm_%j.out -#SBATCH --error=logs/slurm_%j.out -#SBATCH --exclude=gpu3 - -# Overfit test: single GPU, small dataset - -source /home/jaemin/miniforge3/etc/profile.d/conda.sh -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Enable TF32 -export TORCH_ALLOW_TF32_CUBLAS=1 - -# Single GPU training -python train.py --config configs/train_overfit_test.yaml diff --git a/scripts/slurm/run_test_quick.sh b/scripts/slurm/run_test_quick.sh deleted file mode 100755 index 942126d..0000000 --- a/scripts/slurm/run_test_quick.sh +++ /dev/null @@ -1,106 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=flowfix_test -#SBATCH --partition=6000ada -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=4 -#SBATCH --mem=32G -#SBATCH --time=00:30:00 -#SBATCH --output=logs/slurm_test_%j.out -#SBATCH --error=logs/slurm_test_%j.out - -source ~/.bashrc -conda activate torch-2.8 - -cd /home/jaemin/project/protein-ligand/pose-refine - -# Quick test with small data -python -c " -import torch -from src.data.dataset import FlowFixDataset, collate_flowfix_batch -from src.models.flowmatching import ProteinLigandFlowMatchingJoint -from src.utils.model_builder import build_model -from torch.utils.data import DataLoader -import yaml - -print('='*60) -print('Testing Dataset + Model Integration') -print('='*60) - -# Load config -with open('configs/train_joint.yaml') as f: - config = yaml.safe_load(f) - -# Create small dataset -print('\\n1. 
Creating dataset...') -ds = FlowFixDataset( - data_dir='train_data', - split='train', - max_samples=4, - cross_edge_cutoff=6.0, - cross_edge_max_neighbors=16, - intra_edge_cutoff=6.0, - intra_edge_max_neighbors=16, -) -ds.set_epoch(0) -print(f' Dataset size: {len(ds)}') - -# Create dataloader -print('\\n2. Creating dataloader...') -loader = DataLoader(ds, batch_size=2, collate_fn=collate_flowfix_batch, num_workers=0) -batch = next(iter(loader)) -print(f' Batch keys: {list(batch.keys())}') -print(f' t shape: {batch[\"t\"].shape}') -print(f' cross_edge_index shape: {batch[\"cross_edge_index\"].shape}') -print(f' intra_edge_index shape: {batch[\"intra_edge_index\"].shape}') - -# Create model -print('\\n3. Building model...') -model = build_model(config['model']).cuda() -print(f' Model type: {type(model).__name__}') -print(f' Parameters: {sum(p.numel() for p in model.parameters()):,}') - -# Move batch to GPU -print('\\n4. Moving batch to GPU...') -protein_batch = batch['protein_graph'].to('cuda') -ligand_batch = batch['ligand_graph'].to('cuda') -t = batch['t'].to('cuda') -cross_edge_index = batch['cross_edge_index'].to('cuda') -intra_edge_index = batch['intra_edge_index'].to('cuda') -ligand_coords_x0 = batch['ligand_coords_x0'].to('cuda') -ligand_coords_x1 = batch['ligand_coords_x1'].to('cuda') - -# Forward pass -print('\\n5. Forward pass...') -velocity = model(protein_batch, ligand_batch, t, cross_edge_index, intra_edge_index) -print(f' Velocity shape: {velocity.shape}') -print(f' Expected: {ligand_batch.pos.shape}') - -# Self-conditioning test -print('\\n6. Self-conditioning test...') -t_expanded = t[ligand_batch.batch].unsqueeze(-1) -x1_self_cond = ligand_batch.pos + (1 - t_expanded) * velocity -velocity2 = model(protein_batch, ligand_batch, t, cross_edge_index, intra_edge_index, x1_self_cond=x1_self_cond) -print(f' Velocity2 shape: {velocity2.shape}') - -# Loss computation -print('\\n7. 
Loss computation...') -true_velocity = ligand_coords_x1 - ligand_coords_x0 -loss = torch.mean((velocity - true_velocity) ** 2) -print(f' Loss: {loss.item():.6f}') - -# Backward pass -print('\\n8. Backward pass...') -loss.backward() -print(' Backward: OK') - -# Validation mode (no edges passed - should use fallback) -print('\\n9. Validation mode (edge fallback)...') -model.eval() -with torch.no_grad(): - velocity_val = model(protein_batch, ligand_batch, t) # No edges - uses fallback -print(f' Velocity shape: {velocity_val.shape}') - -print('\\n' + '='*60) -print('All tests PASSED!') -print('='*60) -" diff --git a/scripts/slurm/run_train_joint.sh b/scripts/slurm/run_train_joint.sh deleted file mode 100755 index fb6905c..0000000 --- a/scripts/slurm/run_train_joint.sh +++ /dev/null @@ -1,95 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=flowfix_joint -#SBATCH --output=logs/slurm_%j.out -#SBATCH --partition=6000ada # GPU partition -#SBATCH --nodes=1 # Single node -#SBATCH --ntasks-per-node=1 # One task per node -#SBATCH --cpus-per-task=32 # CPUs for data loading -#SBATCH --gres=gpu:8 # Request 8 GPUs -#SBATCH --mem=0 # Request all available memory -#SBATCH --time=30-00:00:00 # 30 days time limit -#SBATCH --exclude=gpu3 # Exclude gpu3 (driver mismatch) - -# Python path -PYTHON=/home/jaemin/miniforge3/envs/torch-2.8/bin/python - -# Project directory -PROJECT_DIR=/home/jaemin/project/protein-ligand/pose-refine - -# Resume checkpoint (set this to resume training, leave empty for fresh start) -CHECKPOINT_PATH="" - -# Print job info -echo "==========================================" -echo "SLURM Job ID: $SLURM_JOB_ID" -echo "Node: $SLURM_NODELIST" -echo "Start time: $(date)" -echo "Working directory: $PROJECT_DIR" -echo "Config: configs/train_joint.yaml" -echo "Number of GPUs: $SLURM_GPUS_ON_NODE" -echo "==========================================" - -# Print GPU info -nvidia-smi - -# Check Python and CUDA -echo "" -echo "Python: $PYTHON" -echo "Python version: $($PYTHON --version)" 
-echo "PyTorch version: $($PYTHON -c 'import torch; print(torch.__version__)')" -echo "CUDA available: $($PYTHON -c 'import torch; print(torch.cuda.is_available())')" -echo "CUDA version: $($PYTHON -c 'import torch; print(torch.version.cuda)')" -echo "Number of CUDA devices: $($PYTHON -c 'import torch; print(torch.cuda.device_count())')" -echo "" - -# Create logs directory if it doesn't exist -mkdir -p "$PROJECT_DIR/logs" - -# Change to project directory -cd "$PROJECT_DIR" - -# Disable Python output buffering for real-time logs -export PYTHONUNBUFFERED=1 - -# Set NCCL environment variables -export NCCL_DEBUG=WARN -export NCCL_P2P_DISABLE=1 -export NCCL_IB_DISABLE=1 - -# Set number of GPUs -if [ -n "$SLURM_GPUS_ON_NODE" ]; then - NUM_GPUS=$SLURM_GPUS_ON_NODE -elif [ -n "$SLURM_GPUS_PER_NODE" ]; then - NUM_GPUS=$SLURM_GPUS_PER_NODE -elif [ -n "$SLURM_JOB_GPUS" ]; then - NUM_GPUS=$(echo $SLURM_JOB_GPUS | awk -F',' '{print NF}') -else - NUM_GPUS=8 - echo "Warning: SLURM GPU variables not found, using default NUM_GPUS=8" -fi - -# Build resume argument if checkpoint path is set -RESUME_ARG="" -if [ -n "$CHECKPOINT_PATH" ]; then - RESUME_ARG="--resume $CHECKPOINT_PATH" - echo "Resuming from checkpoint: $CHECKPOINT_PATH" -fi - -# Run distributed training using torch.distributed.run -echo "Starting multi-GPU training with $NUM_GPUS GPUs..." 
-echo "Command: python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS train.py --config configs/train_joint.yaml $RESUME_ARG" -echo "" - -$PYTHON -m torch.distributed.run \ - --standalone \ - --nnodes=1 \ - --nproc_per_node=$NUM_GPUS \ - train.py \ - --config configs/train_joint.yaml \ - $RESUME_ARG - -# Print end time -echo "" -echo "==========================================" -echo "End time: $(date)" -echo "==========================================" diff --git a/scripts/slurm/run_train_joint_test.sh b/scripts/slurm/run_train_joint_test.sh deleted file mode 100755 index 952272a..0000000 --- a/scripts/slurm/run_train_joint_test.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=flowfix_test -#SBATCH --output=logs/slurm_%j.out -#SBATCH --partition=test # Test partition (gpu2, 2h limit) -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=16 -#SBATCH --gres=gpu:a5000:1 # 1 A5000 GPU -#SBATCH --mem=32G - -# Python path -PYTHON=/home/jaemin/miniforge3/envs/torch-2.8/bin/python - -# Project directory -PROJECT_DIR=/home/jaemin/project/protein-ligand/pose-refine - -# Print job info -echo "==========================================" -echo "SLURM Job ID: $SLURM_JOB_ID" -echo "Node: $SLURM_NODELIST" -echo "Start time: $(date)" -echo "Config: configs/train_joint_test.yaml" -echo "==========================================" - -nvidia-smi - -echo "" -echo "Python: $PYTHON" -$PYTHON --version -$PYTHON -c "import torch; print(f'PyTorch {torch.__version__}, CUDA avail: {torch.cuda.is_available()}, devices: {torch.cuda.device_count()}')" -echo "" - -mkdir -p "$PROJECT_DIR/logs" -cd "$PROJECT_DIR" - -export PYTHONUNBUFFERED=1 - -# Single-GPU test run (no distributed) -echo "Starting single-GPU test run..." 
-$PYTHON train.py --config configs/train_joint_test.yaml - -echo "" -echo "==========================================" -echo "End time: $(date)" -echo "==========================================" diff --git a/scripts/slurm/run_visualize_trajectory.sh b/scripts/slurm/run_visualize_trajectory.sh deleted file mode 100755 index 676ad80..0000000 --- a/scripts/slurm/run_visualize_trajectory.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=flowfix_viz -#SBATCH --output=logs/viz_%j.out -#SBATCH --error=logs/viz_%j.err -#SBATCH --partition=g4090_short,6000ada_short,h100_short -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=8 -#SBATCH --gres=gpu:1 - -echo "==============================================" -echo "FlowFix Trajectory Visualization" -echo "==============================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $SLURM_NODELIST" -echo "Start time: $(date)" -echo "==============================================" - -cd /home/sim/project/flowfix -mkdir -p logs - -PYTHON=/home/sim/miniconda3/envs/protein-ligand/bin/python - -echo "" -echo "Running trajectory visualization..." -$PYTHON -u scripts/analysis/visualize_trajectory.py - -echo "" -echo "==============================================" -echo "Visualization completed!" -echo "End time: $(date)" -echo "==============================================" diff --git a/src/data/dataset_torsion.py b/src/data/dataset_torsion.py new file mode 100644 index 0000000..9a55a1a --- /dev/null +++ b/src/data/dataset_torsion.py @@ -0,0 +1,206 @@ +""" +Torsion-aware dataset and collation for SE(3) + Torsion decomposition training. + +FlowFixTorsionDataset extends FlowFixDataset to compute and return +torsion decomposition data (translation, rotation, torsion_changes, mask_rotate). 
+""" + +import torch +import numpy as np +from typing import Optional, List, Dict, Any + +from .dataset import FlowFixDataset, collate_flowfix_batch + + +class FlowFixTorsionDataset(FlowFixDataset): + """ + FlowFixDataset with SE(3) + Torsion decomposition. + + Returns additional torsion_data dict with: + - translation [3], rotation [3] + - torsion_changes [M], rotatable_edges [M, 2], mask_rotate [M, N] + """ + + def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]: + sample = super().__getitem__(idx) + if sample is None: + return None + + # Compute torsion decomposition + pdb_id = self.pdb_ids[idx] + pdb_dir = self.data_dir / pdb_id + + # Reload ligand_data to access edges (parent only returns coords) + if self.loading_mode == "preload": + ligands_list = self.preloaded_data[pdb_id]['ligands'] + elif self.loading_mode == "hybrid": + ligands_list = torch.load(pdb_dir / "ligands.pt", weights_only=False) + else: + ligands_list = torch.load(pdb_dir / "ligands.pt", weights_only=False) + + # Use same pose index as parent (deterministic via epoch + idx) + rng = np.random.RandomState(self.seed + self.epoch * 10000 + idx) + pose_idx = rng.randint(0, len(ligands_list)) + ligand_data = ligands_list[pose_idx] + + torsion_data = _compute_torsion_data( + ligand_data, + sample['ligand_coords_x0'], + sample['ligand_coords_x1'], + ) + sample['torsion_data'] = torsion_data + return sample + + +def _compute_torsion_data( + ligand_data: dict, + coords_x0: torch.Tensor, + coords_x1: torch.Tensor, +) -> Optional[Dict[str, torch.Tensor]]: + """ + Compute SE(3) + Torsion decomposition from ligand data. + + Returns dict with translation, rotation, torsion_changes, rotatable_edges, mask_rotate. 
+ """ + # Use pre-computed if available + if 'torsion_changes' in ligand_data and 'mask_rotate' in ligand_data: + return { + 'translation': ligand_data.get('translation', torch.zeros(3)), + 'rotation': ligand_data.get('rotation', torch.zeros(3)), + 'torsion_changes': ligand_data['torsion_changes'], + 'rotatable_edges': ligand_data.get('rotatable_edges', torch.zeros(0, 2, dtype=torch.long)), + 'mask_rotate': ligand_data['mask_rotate'], + } + + # Compute on-the-fly + try: + from src.data.ligand_feat import compute_rigid_transform, get_transformation_mask + + translation, rotation = compute_rigid_transform(coords_x0, coords_x1) + + # Get edges + edges = None + if 'edges' in ligand_data: + edges = ligand_data['edges'] + elif 'edge' in ligand_data and 'edges' in ligand_data['edge']: + edges = ligand_data['edge']['edges'] + + n_atoms = coords_x0.shape[0] + empty_result = { + 'translation': translation, + 'rotation': rotation, + 'torsion_changes': torch.zeros(0), + 'rotatable_edges': torch.zeros(0, 2, dtype=torch.long), + 'mask_rotate': torch.zeros(0, n_atoms, dtype=torch.bool), + } + + if edges is None: + return empty_result + + mask_rotate, rotatable_edge_indices = get_transformation_mask(edges, n_atoms) + + if len(rotatable_edge_indices) == 0: + return empty_result + + edges_np = edges.numpy() if torch.is_tensor(edges) else edges + rot_edges = torch.tensor( + [[int(edges_np[0, i]), int(edges_np[1, i])] for i in rotatable_edge_indices], + dtype=torch.long, + ) + mask_rot_filtered = mask_rotate[rotatable_edge_indices] + + return { + 'translation': translation, + 'rotation': rotation, + 'torsion_changes': torch.zeros(len(rotatable_edge_indices)), + 'rotatable_edges': rot_edges, + 'mask_rotate': mask_rot_filtered, + } + + except Exception: + return None + + +def collate_torsion_batch(samples: List[Dict]) -> Dict[str, Any]: + """ + Collate batch with torsion data. + + Wraps collate_flowfix_batch and adds torsion_data collation. 
+ """ + # Filter None + samples = [s for s in samples if s is not None] + if not samples: + raise ValueError("All samples in batch are None!") + + # Base collation (protein, ligand, coords, distance bounds) + batch = collate_flowfix_batch(samples) + + # Collate torsion data + torsion_list = [s.get('torsion_data', None) for s in samples] + batch['torsion_data'] = _collate_torsion_data(torsion_list, batch['ligand_graph']) + + return batch + + +def _collate_torsion_data( + torsion_data_list: List[Optional[Dict[str, torch.Tensor]]], + ligand_batch, +) -> Optional[Dict[str, torch.Tensor]]: + """ + Collate torsion data across batch samples. + + Concatenates variable-length rotatable bonds, adjusts atom indices + with batch offsets, and expands mask_rotate to full batch size. + """ + valid = [td for td in torsion_data_list if td is not None] + if not valid: + return None + + total_atoms = ligand_batch.num_nodes + atom_counts = torch.bincount(ligand_batch.batch) + atom_offsets = torch.cat([torch.zeros(1, dtype=torch.long), atom_counts.cumsum(0)[:-1]]) + + translations = [] + rotations = [] + torsion_changes = [] + rotatable_edges = [] + mask_rotate_list = [] + + for i, td in enumerate(torsion_data_list): + if td is None: + translations.append(torch.zeros(3)) + rotations.append(torch.zeros(3)) + continue + + translations.append(td['translation']) + rotations.append(td['rotation']) + + if td['torsion_changes'].numel() > 0: + torsion_changes.append(td['torsion_changes']) + + offset = atom_offsets[i].item() + edges = td['rotatable_edges'].clone() + offset + rotatable_edges.append(edges) + + # Expand mask to full batch atom count + m_i = td['mask_rotate'].shape[0] + n_i = td['mask_rotate'].shape[1] + full_mask = torch.zeros(m_i, total_atoms, dtype=torch.bool) + full_mask[:, offset:offset + n_i] = td['mask_rotate'] + mask_rotate_list.append(full_mask) + + result = { + 'translation': torch.stack(translations), + 'rotation': torch.stack(rotations), + } + + if torsion_changes: + 
result['torsion_changes'] = torch.cat(torsion_changes) + result['rotatable_edges'] = torch.cat(rotatable_edges) + result['mask_rotate'] = torch.cat(mask_rotate_list) + else: + result['torsion_changes'] = torch.zeros(0) + result['rotatable_edges'] = torch.zeros(0, 2, dtype=torch.long) + result['mask_rotate'] = torch.zeros(0, total_atoms, dtype=torch.bool) + + return result diff --git a/src/models/flowmatching_torsion.py b/src/models/flowmatching_torsion.py new file mode 100644 index 0000000..c8d1e6e --- /dev/null +++ b/src/models/flowmatching_torsion.py @@ -0,0 +1,182 @@ +""" +SE(3) + Torsion Decomposition Flow Matching Model. + +Instead of predicting per-atom Cartesian velocity [N, 3], this model predicts: +- Translation [B, 3]: CoM displacement +- Rotation [B, 3]: Axis-angle rotation around CoM +- Torsion [M]: One scalar per rotatable bond + +Inherits the encoder and interaction network from ProteinLigandFlowMatching, +replacing only the output heads. +""" + +import torch +import torch.nn as nn +import cuequivariance as cue_base +from torch_scatter import scatter_mean + +from .cue_layers import EquivariantMLP +from .torch_layers import MLP +from .flowmatching import ProteinLigandFlowMatching + + +class ProteinLigandFlowMatchingTorsion(ProteinLigandFlowMatching): + """ + SE(3) + Torsion decomposition variant of ProteinLigandFlowMatching. + + Shares encoder, interaction network, and velocity blocks with the base class. 
+ Replaces the output heads with: + - Translation head: mean-pool → EquivariantMLP → [B, 3] + - Rotation head: mean-pool → EquivariantMLP → [B, 3] (axis-angle) + - Torsion head: src/dst node scalar concat → MLP → [M, 1] + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Get dimensions from parent + vel_hidden_scalar_dim = kwargs.get('velocity_hidden_scalar_dim', 128) + vel_hidden_vector_dim = kwargs.get('velocity_hidden_vector_dim', 16) + dropout = kwargs.get('dropout', 0.1) + + self._vel_hidden_scalar_dim = vel_hidden_scalar_dim + + vel_hidden_irreps = cue_base.Irreps( + "O3", + f"{vel_hidden_scalar_dim}x0e + {vel_hidden_vector_dim}x1o + {vel_hidden_vector_dim}x1e" + ) + intermediate_irreps = cue_base.Irreps( + "O3", + f"{vel_hidden_scalar_dim}x0e + {vel_hidden_vector_dim}x1o + {vel_hidden_vector_dim}x1e" + ) + vector_output_irreps = cue_base.Irreps("O3", "1x1o") + + # Translation head: pooled features → 3D vector + self.translation_output = EquivariantMLP( + irreps_in=vel_hidden_irreps, + irreps_hidden=intermediate_irreps, + irreps_out=vector_output_irreps, + num_layers=2, + dropout=dropout, + ) + + # Rotation head: pooled features → 3D axis-angle + self.rotation_output = EquivariantMLP( + irreps_in=vel_hidden_irreps, + irreps_hidden=intermediate_irreps, + irreps_out=vector_output_irreps, + num_layers=2, + dropout=dropout, + ) + + # Torsion head: src/dst node scalars → 1 scalar per rotatable bond + torsion_input_dim = vel_hidden_scalar_dim * 2 + self.torsion_output = MLP( + in_dim=torsion_input_dim, + hidden_dim=vel_hidden_scalar_dim, + out_dim=1, + num_layers=2, + activation='silu', + ) + + # Zero-initialize all output heads for stable training + self._zero_init_output_heads() + + # Learnable scales + self.translation_scale = nn.Parameter(torch.ones(1) * 0.1) + self.rotation_scale = nn.Parameter(torch.ones(1) * 0.1) + self.torsion_scale = nn.Parameter(torch.ones(1) * 0.1) + + def _zero_init_output_heads(self): + 
"""Zero-initialize output heads for stable training start.""" + for head in [self.torsion_output]: + for module in head.layers: + if isinstance(module, nn.Linear): + nn.init.zeros_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def _encode(self, protein_batch, ligand_batch, t): + """ + Run shared encoder + interaction + velocity blocks. + + Returns h: [N_ligand, hidden_irreps] features after message passing. + """ + # ESM embeddings + if self.use_esm_embeddings: + protein_batch = self._integrate_esm_embeddings(protein_batch) + + protein_output = self.protein_network(protein_batch) + ligand_output = self.ligand_network(ligand_batch) + + # Interaction + (_, lig_out), (protein_context, _), _ = self.interaction_network( + protein_output, ligand_output, protein_batch, ligand_batch + ) + + # Conditioning + protein_context_expanded = protein_context[ligand_batch.batch] + combined_condition = torch.cat([protein_context_expanded, lig_out], dim=-1) + atom_condition = self.vel_atom_condition_proj(combined_condition) + + # Velocity blocks (shared backbone) + h = self.vel_input_projection(ligand_output) + h_initial = h + + for block in self.velocity_blocks: + h = block( + h, + ligand_batch.pos, + ligand_batch.edge_index, + ligand_batch.edge_attr, + condition=atom_condition, + ) + + h = h + h_initial + return h + + def forward( + self, + protein_batch, + ligand_batch, + t: torch.Tensor, + rotatable_edges: torch.Tensor = None, + ) -> dict: + """ + Predict SE(3) + Torsion velocity. 
+ + Args: + protein_batch: Protein PyG batch + ligand_batch: Ligand PyG batch at time t + t: [B] timesteps + rotatable_edges: [M, 2] atom indices of rotatable bonds + + Returns: + Dict with 'translation' [B, 3], 'rotation' [B, 3], 'torsion' [M] + """ + h = self._encode(protein_batch, ligand_batch, t) + + # Translation: mean-pool → 3D + h_pooled = scatter_mean(h, ligand_batch.batch, dim=0) + translation = self.translation_output(h_pooled) * self.translation_scale + + # Rotation: mean-pool → 3D axis-angle + rotation = self.rotation_output(h_pooled) * self.rotation_scale + + # Torsion: src/dst scalar features → 1 scalar per bond + if rotatable_edges is not None and rotatable_edges.shape[0] > 0: + src_idx = rotatable_edges[:, 0] + dst_idx = rotatable_edges[:, 1] + + # Extract scalar part (first scalar_dim components of irreps) + h_scalar = h[:, :self._vel_hidden_scalar_dim] + edge_feat = torch.cat([h_scalar[src_idx], h_scalar[dst_idx]], dim=-1) + torsion = self.torsion_output(edge_feat).squeeze(-1) * self.torsion_scale + else: + torsion = torch.zeros(0, device=h.device) + + return { + 'translation': translation, + 'rotation': rotation, + 'torsion': torsion, + } diff --git a/src/utils/loss_utils.py b/src/utils/loss_utils.py deleted file mode 100644 index a24d98a..0000000 --- a/src/utils/loss_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Simplified loss utilities for FlowFix training. -Only velocity matching loss with logistic-normal time sampling. -""" - -import torch -import torch.distributions as dist - - -def compute_flow_matching_loss(pred_velocity, true_velocity, batch_indices): - """ - Simple uniform flow matching loss: pure MSE on velocity field. - - For linear interpolation, true velocity is constant across all t, - so all timesteps should be equally weighted. 
- - Args: - pred_velocity: Predicted velocity, shape [N, 3] - true_velocity: True velocity (x1 - x0), shape [N, 3] - batch_indices: Which batch each atom belongs to, shape [N] (unused but kept for compatibility) - - Returns: - loss: Uniform MSE loss - loss_dict: Dictionary with loss components - """ - # Simple uniform MSE loss (no per-sample normalization, no weighting) - total_loss = torch.nn.functional.mse_loss(pred_velocity, true_velocity) - - loss_dict = { - 'velocity_loss': total_loss.item(), - 'total_loss': total_loss.item() - } - - return total_loss, loss_dict - - -def sample_timesteps_logistic_normal(batch_size, device, mu=0.0, sigma=1.0, mix_ratio=0.5): - """ - Sample timesteps using logistic-normal distribution. - - Mixture of: - - Logistic-normal: sigma(Normal(mu, sigma)) where sigma is sigmoid - - Uniform: for diversity - - Args: - batch_size: Number of timesteps to sample - device: torch device - mu: Mean of underlying normal (logit space) - sigma: Std of underlying normal (logit space) - mix_ratio: Fraction to sample from logistic-normal (rest is uniform) - - Returns: - t: Sampled timesteps in [0, 1], shape [batch_size] - """ - n_logistic = int(batch_size * mix_ratio) - n_uniform = batch_size - n_logistic - - # Logistic-normal samples - normal_dist = dist.Normal(mu, sigma) - logit_samples = normal_dist.sample((n_logistic,)).to(device) - logistic_samples = torch.sigmoid(logit_samples) - - # Uniform samples - uniform_samples = torch.rand(n_uniform, device=device) - - # Concatenate and shuffle - t = torch.cat([logistic_samples, uniform_samples], dim=0) - t = t[torch.randperm(len(t), device=device)] - - return t diff --git a/src/utils/losses_torsion.py b/src/utils/losses_torsion.py new file mode 100644 index 0000000..5870a34 --- /dev/null +++ b/src/utils/losses_torsion.py @@ -0,0 +1,169 @@ +""" +Loss functions for SE(3) + Torsion decomposition training. 
+""" + +import torch +import torch.nn.functional as F + + +def compute_se3_torsion_loss( + pred: dict, + target: dict, + coords_x0: torch.Tensor, + coords_x1: torch.Tensor, + mask_rotate: torch.Tensor, + rotatable_edges: torch.Tensor, + batch_indices: torch.Tensor, + w_trans: float = 1.0, + w_rot: float = 1.0, + w_tor: float = 1.0, + w_coord: float = 0.5, +) -> dict: + """ + Compute SE(3) + Torsion decomposition loss. + + Args: + pred: Model output with 'translation' [B,3], 'rotation' [B,3], 'torsion' [M] + target: Target with 'translation' [B,3], 'rotation' [B,3], 'torsion_changes' [M] + coords_x0: Docked coordinates [N, 3] + coords_x1: Crystal coordinates [N, 3] + mask_rotate: [M, N] boolean mask + rotatable_edges: [M, 2] atom indices + batch_indices: [N] batch assignment + w_trans, w_rot, w_tor, w_coord: Component weights + + Returns: + Dict with 'total', 'translation', 'rotation', 'torsion', 'coord_recon' losses + """ + device = pred['translation'].device + + # Translation MSE + loss_trans = F.mse_loss(pred['translation'], target['translation'].to(device)) + + # Rotation MSE (axis-angle) + loss_rot = F.mse_loss(pred['rotation'], target['rotation'].to(device)) + + # Torsion circular MSE + if pred['torsion'].numel() > 0 and target['torsion_changes'].numel() > 0: + target_tor = target['torsion_changes'].to(device) + diff = pred['torsion'] - target_tor + diff = torch.atan2(torch.sin(diff), torch.cos(diff)) # wrap to [-pi, pi] + loss_tor = (diff ** 2).mean() + else: + loss_tor = torch.zeros(1, device=device).squeeze() + + # Coordinate reconstruction loss + loss_coord = torch.zeros(1, device=device).squeeze() + if w_coord > 0: + reconstructed = reconstruct_coords( + coords_x0, pred, mask_rotate, rotatable_edges, batch_indices + ) + loss_coord = F.mse_loss(reconstructed, coords_x1.to(device)) + + total = w_trans * loss_trans + w_rot * loss_rot + w_tor * loss_tor + w_coord * loss_coord + + return { + 'total': total, + 'translation': loss_trans.detach(), + 'rotation': 
loss_rot.detach(), + 'torsion': loss_tor.detach(), + 'coord_recon': loss_coord.detach(), + } + + +def reconstruct_coords( + coords_x0: torch.Tensor, + pred: dict, + mask_rotate: torch.Tensor, + rotatable_edges: torch.Tensor, + batch_indices: torch.Tensor, +) -> torch.Tensor: + """ + Reconstruct coordinates from SE(3) + Torsion prediction. + + Apply order: Torsion → Translation → Rotation. + """ + device = coords_x0.device + coords = coords_x0.clone() + batch_size = pred['translation'].shape[0] + + for b in range(batch_size): + mol_mask = (batch_indices == b) + mol_coords = coords[mol_mask] + mol_indices = torch.where(mol_mask)[0] + n_atoms = mol_indices.shape[0] + offset = mol_indices[0].item() + + # 1. Torsion + if pred['torsion'].numel() > 0 and mask_rotate.shape[0] > 0: + mol_coords = _apply_torsions( + mol_coords, pred['torsion'], mask_rotate, + rotatable_edges, offset, n_atoms, device + ) + + # 2. Translation + mol_coords = mol_coords + pred['translation'][b] + + # 3. Rotation around CoM + mol_coords = _apply_rotation(mol_coords, pred['rotation'][b]) + + coords[mol_mask] = mol_coords + + return coords + + +def _apply_torsions(mol_coords, torsion_values, mask_rotate, rotatable_edges, offset, n_atoms, device): + """Apply torsion angle changes to molecule coordinates.""" + for m in range(mask_rotate.shape[0]): + angle = torsion_values[m] + if angle.abs() < 1e-6: + continue + + mask = mask_rotate[m, offset:offset + n_atoms] + if not mask.any(): + continue + + src, dst = rotatable_edges[m] + src_l = src.item() - offset + dst_l = dst.item() - offset + if not (0 <= src_l < n_atoms and 0 <= dst_l < n_atoms): + continue + + axis = mol_coords[dst_l] - mol_coords[src_l] + axis_norm = axis.norm() + if axis_norm < 1e-6: + continue + axis = axis / axis_norm + + mol_coords = _rodrigues_rotate(mol_coords, mask, axis, mol_coords[dst_l], angle) + + return mol_coords + + +def _rodrigues_rotate(coords, mask, axis, pivot, angle): + """Rodrigues rotation of masked atoms around axis 
through pivot.""" + relative = coords[mask] - pivot + cos_a = torch.cos(angle) + sin_a = torch.sin(angle) + dot = (relative * axis).sum(dim=-1, keepdim=True) + cross = torch.cross(axis.unsqueeze(0).expand_as(relative), relative, dim=-1) + rotated = relative * cos_a + cross * sin_a + axis * dot * (1 - cos_a) + coords = coords.clone() + coords[mask] = rotated + pivot + return coords + + +def _apply_rotation(coords, rot_vec): + """Apply axis-angle rotation around center of mass.""" + angle = rot_vec.norm() + if angle < 1e-6: + return coords + com = coords.mean(dim=0) + relative = coords - com + axis = rot_vec / angle + cos_a = torch.cos(angle) + sin_a = torch.sin(angle) + dot = (relative * axis).sum(dim=-1, keepdim=True) + cross = torch.cross(axis.unsqueeze(0).expand_as(relative), relative, dim=-1) + rotated = relative * cos_a + cross * sin_a + axis * dot * (1 - cos_a) + return rotated + com diff --git a/src/utils/metrics.py b/src/utils/metrics.py deleted file mode 100644 index 5b17c95..0000000 --- a/src/utils/metrics.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -Metrics utilities for FlowFix training. - -Contains functions for: -- RMSD calculation -- Success rate calculation -- DDP metric gathering -""" - -from typing import List, Dict, Any -import numpy as np -import torch -import torch.distributed as dist - - -def compute_success_rates(rmsds: np.ndarray) -> Dict[str, float]: - """ - Compute success rates at various RMSD thresholds. - - Args: - rmsds: Array of RMSD values - - Returns: - Dict with success rates: 'success_2A', 'success_1A', 'success_05A' - """ - return { - "success_2A": float(np.mean(rmsds < 2.0) * 100), - "success_1A": float(np.mean(rmsds < 1.0) * 100), - "success_05A": float(np.mean(rmsds < 0.5) * 100), - } - - -def gather_metrics_ddp( - local_metrics: Dict[str, List[float]], - world_size: int, - device: torch.device, -) -> Dict[str, np.ndarray]: - """ - Gather metrics from all DDP processes. 
- - Handles variable-length lists across processes by padding before all_gather. - - Args: - local_metrics: Dict of metric_name -> list of values - world_size: Number of DDP processes - device: Torch device - - Returns: - Dict of metric_name -> numpy array of all gathered values - """ - gathered = {} - - for metric_name, values in local_metrics.items(): - values_tensor = torch.tensor(values, device=device) - - # Gather sizes from all ranks - local_size = torch.tensor([len(values)], device=device, dtype=torch.long) - all_sizes = [ - torch.zeros(1, device=device, dtype=torch.long) for _ in range(world_size) - ] - dist.all_gather(all_sizes, local_size) - - # Find max size for padding - max_size = max(s.item() for s in all_sizes) - - # Pad tensor to max size - padded = torch.zeros(max_size, device=device) - padded[: len(values)] = values_tensor - - # Gather padded tensors - gathered_tensors = [ - torch.zeros(max_size, device=device) for _ in range(world_size) - ] - dist.all_gather(gathered_tensors, padded) - - # Unpad and concatenate - unpadded = [ - gathered_tensors[i][: all_sizes[i].item()] for i in range(world_size) - ] - gathered[metric_name] = torch.cat(unpadded).cpu().numpy() - - return gathered - - -def average_metrics_ddp( - local_metrics: List[float], - world_size: int, - device: torch.device, -) -> float: - """ - Average a list of metrics across all DDP processes. 
- - Args: - local_metrics: List of metric values from this process - world_size: Number of DDP processes - device: Torch device - - Returns: - Averaged metric value - """ - # Convert to tensor and compute local stats - tensor = torch.tensor(local_metrics, device=device) - - # All-reduce sum - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - - # Average across GPUs and samples - return tensor.mean().item() / world_size diff --git a/src/utils/plot.py b/src/utils/plot.py deleted file mode 100644 index f12b244..0000000 --- a/src/utils/plot.py +++ /dev/null @@ -1,762 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -import os -import seaborn as sns -from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, average_precision_score -import itertools -from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score -from scipy.stats import pearsonr -import pandas as pd - -def plot_scatter(y_true, y_pred, save_dir, type='valid', logger=None): - """ - Draw scatter plot of true vs predicted values with regression metrics. 
- - Args: - y_true: Ground truth values - y_pred: Predicted values - save_dir: Directory to save the plot - type: Type of data ('train' or 'valid') - logger: Logger object for logging messages - """ - plt.figure(figsize=(10, 10)) - - # Calculate metrics - rmse = np.sqrt(mean_squared_error(y_true, y_pred)) - mae = mean_absolute_error(y_true, y_pred) - r2 = r2_score(y_true, y_pred) - pearson_corr, _ = pearsonr(y_true, y_pred) - - # Create scatter plot - plt.scatter(y_true, y_pred, alpha=0.5, color='blue', label='Data points') - - # Plot the perfect prediction line - min_val = min(min(y_true), min(y_pred)) - max_val = max(max(y_true), max(y_pred)) - plt.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect prediction') - - # Add metrics text box - metrics_text = f'RMSE: {rmse:.4f}\nMAE: {mae:.4f}\nR²: {r2:.4f}\nPearson: {pearson_corr:.4f}' - plt.text(0.05, 0.95, metrics_text, - transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), - verticalalignment='top', - fontsize=10) - - plt.xlabel('True Values') - plt.ylabel('Predicted Values') - plt.title(f'{type.capitalize()} Set: True vs Predicted Values') - plt.legend() - plt.grid(True, linestyle='--', alpha=0.7) - - # Make plot square with equal axes - plt.axis('square') - - # Add colorbar to show density - from scipy.stats import gaussian_kde - xy = np.vstack([y_true, y_pred]) - z = gaussian_kde(xy)(xy) - idx = z.argsort() - x, y, z = y_true[idx], y_pred[idx], z[idx] - scatter = plt.scatter(x, y, c=z, s=50, alpha=0.5, cmap='viridis') - plt.colorbar(scatter, label='Density') - - save_path = os.path.join(save_dir, f'{type}_scatter_plot.png') - plt.savefig(save_path, dpi=300, bbox_inches='tight') - plt.close() - - if logger: - logger.info(f"Scatter plot saved to {save_path}") - - return { - 'RMSE': rmse, - 'MAE': mae, - 'R2': r2, - 'Pearson': pearson_corr - } - -def plot_losses(train_loss, valid_loss, save_dir, best_epoch=None, logger=None): - """ - Plot training and 
validation losses over epochs - - Args: - train_loss: List of training loss values - valid_loss: List of validation loss values - save_dir: Directory to save the plot - best_epoch: Best epoch index (0-based) - logger: Logger object - """ - if not train_loss or not valid_loss: - if logger: - logger.warning("No loss data to plot") - return - - # Get epochs for x-axis (assuming train_loss and valid_loss have the same length) - start_epoch = 1 # default start epoch - epochs = list(range(start_epoch, start_epoch + len(train_loss))) - - plt.figure(figsize=(10, 6)) - plt.plot(epochs, train_loss, 'b-', label='Training Loss') - plt.plot(epochs, valid_loss, 'r-', label='Validation Loss') - - # Always show best_epoch if provided, but display all epochs regardless - if best_epoch: - plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch') - - plt.title('Training and Validation Loss') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.legend() - plt.grid(True) - - plt.tight_layout() - plt.savefig(os.path.join(save_dir, 'loss_history.png'), dpi=300) - plt.close() - - if logger: - logger.info(f"Loss history plot saved to {save_dir}/loss_history.png") - -def plot_metrics(train_metrics, valid_metrics, save_dir, best_epoch=None, logger=None): - """Plot metrics history - - Args: - train_metrics: List of training metrics dictionaries - valid_metrics: List of validation metrics dictionaries - save_dir: Directory to save plots - best_epoch: Marker for best epoch (optional) - logger: Logger for logging messages - """ - if not train_metrics or not valid_metrics: - if logger: - logger.warning("No metrics to plot") - return - - # Create figures directory - figures_dir = os.path.join(save_dir, 'figures') - os.makedirs(figures_dir, exist_ok=True) - - # Plot regression metrics - plot_regression_metrics(train_metrics, valid_metrics, figures_dir, best_epoch=best_epoch, logger=logger) - - if logger: - logger.info(f"Saved metrics plots to {figures_dir}") - -def 
plot_regression_metrics(train_metrics, valid_metrics, save_dir, best_epoch=None, logger=None): - """Plot regression metrics - - Args: - train_metrics: List of training metrics dictionaries - valid_metrics: List of validation metrics dictionaries - save_dir: Directory to save plots - best_epoch: Marker for best epoch (optional) - logger: Logger for logging messages - """ - # Check if metrics are available - if ( - len(valid_metrics) == 0 or - 'metrics' not in valid_metrics[0] or - 'regression' not in valid_metrics[0]['metrics'] - ): - if logger: - logger.warning("No regression metrics to plot") - return - - # Extract regression metrics - epochs = list(range(1, len(train_metrics) + 1)) - - # Get all metrics (will be available in validation results) - metric_names = list(valid_metrics[0]['metrics']['regression'].keys()) - - # Create plot for each metric - for metric_name in metric_names: - try: - valid_values = [valid_metrics[i]['metrics']['regression'][metric_name] for i in range(len(valid_metrics))] - - plt.figure(figsize=(10, 6)) - - # Plot only validation values (training may not have all metrics) - plt.plot(epochs, valid_values, 'b-', label=f'Validation {metric_name}') - - if best_epoch: - # Mark best epoch - plt.axvline(x=best_epoch, color='r', linestyle='--', label=f'Best Epoch ({best_epoch})') - - # Get best value - best_value = valid_values[best_epoch - 1] if 0 <= best_epoch - 1 < len(valid_values) else None - if best_value is not None: - plt.scatter([best_epoch], [best_value], color='r', s=100, zorder=5) - plt.text(best_epoch, best_value, f' {best_value:.4f}', verticalalignment='bottom') - - plt.xlabel('Epoch') - plt.ylabel(metric_name) - plt.title(f'Regression {metric_name} vs Epoch') - plt.grid(True, linestyle='--', alpha=0.7) - plt.legend() - - # Save figure - plt.tight_layout() - plt.savefig(os.path.join(save_dir, f'regression_{metric_name.lower()}.png'), dpi=300) - plt.close() - - except Exception as e: - if logger: - logger.warning(f"Error plotting 
{metric_name}: {str(e)}") - continue - - if logger: - logger.info(f"Saved regression metrics plots to {save_dir}") - -def plot_confusion_matrix(cm, classes, save_dir, title='Confusion Matrix', task_type=None, cmap=plt.cm.Blues): - """ - Plot confusion matrix - - Args: - cm: Confusion matrix - classes: List of class names - save_dir: Directory to save the plot - title: Plot title - task_type: Task type for filename (optional) - cmap: Colormap - """ - plt.figure(figsize=(8, 6)) - plt.imshow(cm, interpolation='nearest', cmap=cmap) - plt.title(title) - plt.colorbar() - - tick_marks = np.arange(len(classes)) - plt.xticks(tick_marks, classes, rotation=45) - plt.yticks(tick_marks, classes) - - fmt = 'd' - thresh = cm.max() / 2. - for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): - plt.text(j, i, format(cm[i, j], fmt), - horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black") - - plt.tight_layout() - plt.ylabel('True label') - plt.xlabel('Predicted label') - - # 파일명 설정 (task_type이 있으면 사용, 없으면 기본 파일명) - filename = f"{task_type}_cm.png" if task_type else "confusion_matrix.png" - plt.savefig(os.path.join(save_dir, filename)) - plt.close() - -def plot_binding_metrics(y_true, y_pred, save_dir, task_type='binding', epoch=None, logger=None): - """ - Plot metrics for binding or non-binding classification tasks in a single figure - with 3 subplots: confusion matrix, ROC curve, and PR curve. 
- - Args: - y_true: True labels - y_pred: Predicted probabilities - save_dir: Directory to save the plot - task_type: 'binding' or 'non_binding' - epoch: Current epoch number (not used in filename anymore) - logger: Logger object - """ - # Convert predictions to binary (0/1) using 0.5 threshold - y_pred_binary = (y_pred >= 0.5).astype(int) - - # Compute confusion matrix - cm = confusion_matrix(y_true, y_pred_binary) - - # Create a figure with 3 subplots - fig, axes = plt.subplots(1, 3, figsize=(18, 6)) - plt.suptitle(f'{task_type.replace("_", " ").title()} Classification Metrics', fontsize=16) - - # 1. Plot confusion matrix - axes[0].imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) - axes[0].set_title('Confusion Matrix') - - # Add text annotations to confusion matrix - thresh = cm.max() / 2. - for i in range(cm.shape[0]): - for j in range(cm.shape[1]): - axes[0].text(j, i, format(cm[i, j], 'd'), - ha="center", va="center", - color="white" if cm[i, j] > thresh else "black") - - axes[0].set_xticks([0, 1]) - axes[0].set_xticklabels(['Negative', 'Positive']) - axes[0].set_yticks([0, 1]) - axes[0].set_yticklabels(['Negative', 'Positive']) - axes[0].set_ylabel('True label') - axes[0].set_xlabel('Predicted label') - - # 2. Plot ROC curve - fpr, tpr, _ = roc_curve(y_true, y_pred) - roc_auc = auc(fpr, tpr) - - axes[1].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})') - axes[1].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') - axes[1].set_xlim([0.0, 1.0]) - axes[1].set_ylim([0.0, 1.05]) - axes[1].set_xlabel('False Positive Rate') - axes[1].set_ylabel('True Positive Rate') - axes[1].set_title('ROC Curve') - axes[1].legend(loc='lower right') - axes[1].grid(True, alpha=0.3) - - # 3. 
Plot precision-recall curve - precision, recall, _ = precision_recall_curve(y_true, y_pred) - avg_precision = average_precision_score(y_true, y_pred) - - axes[2].plot(recall, precision, color='blue', lw=2, label=f'PR curve (AP = {avg_precision:.3f})') - axes[2].set_xlabel('Recall') - axes[2].set_ylabel('Precision') - axes[2].set_title('Precision-Recall Curve') - axes[2].legend(loc='upper right') - axes[2].grid(True, alpha=0.3) - axes[2].set_ylim([0.0, 1.05]) - - plt.tight_layout() - - # Save the combined figure without epoch number (already implemented) - save_path = os.path.join(save_dir, f'{task_type}_metrics.png') - plt.savefig(save_path, dpi=300) - plt.close() - - if logger: - logger.info(f"{task_type.replace('_', ' ').title()} metrics plot saved to {save_path}") - -def plot_loss_types(train_metrics, valid_metrics, save_dir, best_epoch=None, logger=None): - """ - Plot different types of losses over training epochs in a 2x2 subplot format - """ - # Always start from the first epoch - if 'epoch' in train_metrics[0]: - start_epoch = train_metrics[0]['epoch'] - else: - start_epoch = 1 - - epochs = list(range(start_epoch, start_epoch + len(train_metrics))) - - # Plot loss types (reg_loss, bind_loss, nonbind_loss, total_loss) - plt.figure(figsize=(16, 12)) - plt.suptitle('Loss Types Over Training', fontsize=16) - - # Total Loss - plt.subplot(2, 2, 1) - plt.plot(epochs, [m['total_loss'] for m in train_metrics], 'b-', label='Training') - plt.plot(epochs, [m['losses']['total_loss'] for m in valid_metrics], 'r-', label='Validation') - if best_epoch: - plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.title('Total Loss') - plt.legend() - plt.grid(True) - - # Regression Loss - plt.subplot(2, 2, 2) - plt.plot(epochs, [m['reg_loss'] for m in train_metrics], 'b-', label='Training') - plt.plot(epochs, [m['losses']['reg_loss'] for m in valid_metrics], 'r-', label='Validation') - if best_epoch: - 
plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.title('Regression Loss') - plt.legend() - plt.grid(True) - - # Binding Loss - plt.subplot(2, 2, 3) - plt.plot(epochs, [m['bind_loss'] for m in train_metrics], 'b-', label='Training') - plt.plot(epochs, [m['losses']['bind_loss'] for m in valid_metrics], 'r-', label='Validation') - if best_epoch: - plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.title('Binding Loss') - plt.legend() - plt.grid(True) - - # Non-binding Loss - plt.subplot(2, 2, 4) - plt.plot(epochs, [m['nonbind_loss'] for m in train_metrics], 'b-', label='Training') - plt.plot(epochs, [m['losses']['nonbind_loss'] for m in valid_metrics], 'r-', label='Validation') - if best_epoch: - plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.title('Non-binding Loss') - plt.legend() - plt.grid(True) - - plt.tight_layout() - plt.savefig(os.path.join(save_dir, 'loss_types.png'), dpi=300) - plt.close() - - if logger: - logger.info(f"Loss types plot saved to {save_dir}/loss_types.png") - -# Set style -plt.style.use('seaborn-v0_8') -sns.set_palette("husl") - -def plot_loss_curves(train_losses, valid_losses, save_dir, title="Training History"): - """ - Plot training and validation loss curves - - Args: - train_losses: List of training losses - valid_losses: List of validation losses - save_dir: Directory to save the plot - title: Plot title - """ - plt.figure(figsize=(12, 5)) - - # Loss subplot - plt.subplot(1, 2, 1) - epochs = range(1, len(train_losses) + 1) - - plt.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2) - plt.plot(epochs, valid_losses, 'r-', label='Validation Loss', linewidth=2) - - plt.title('Loss Curves') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.legend() - plt.grid(True, alpha=0.3) - - # Find best epoch (minimum 
validation loss) - best_epoch = np.argmin(valid_losses) + 1 - plt.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}') - plt.legend() - - # Loss difference subplot - plt.subplot(1, 2, 2) - loss_diff = np.array(valid_losses) - np.array(train_losses) - plt.plot(epochs, loss_diff, 'purple', linewidth=2) - plt.title('Validation - Training Loss') - plt.xlabel('Epoch') - plt.ylabel('Loss Difference') - plt.grid(True, alpha=0.3) - plt.axhline(y=0, color='black', linestyle='-', alpha=0.3) - - plt.tight_layout() - plt.savefig(os.path.join(save_dir, 'loss_curves.png'), dpi=300, bbox_inches='tight') - plt.close() - - print(f"Loss curves saved to {save_dir}/loss_curves.png") - -def plot_r2_curves(train_r2_scores, valid_r2_scores, save_dir, title="R² Score History"): - """ - Plot training and validation R² score curves - - Args: - train_r2_scores: List of training R² scores - valid_r2_scores: List of validation R² scores - save_dir: Directory to save the plot - title: Plot title - """ - plt.figure(figsize=(12, 5)) - - # R² subplot - plt.subplot(1, 2, 1) - epochs = range(1, len(train_r2_scores) + 1) - - plt.plot(epochs, train_r2_scores, 'b-', label='Training R²', linewidth=2) - plt.plot(epochs, valid_r2_scores, 'r-', label='Validation R²', linewidth=2) - - plt.title('R² Score Curves') - plt.xlabel('Epoch') - plt.ylabel('R² Score') - plt.legend() - plt.grid(True, alpha=0.3) - - # Find best epoch (maximum validation R²) - best_epoch = np.argmax(valid_r2_scores) + 1 - best_r2 = valid_r2_scores[best_epoch - 1] - plt.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}') - plt.scatter([best_epoch], [best_r2], color='g', s=100, zorder=5) - plt.text(best_epoch, best_r2, f' {best_r2:.4f}', verticalalignment='bottom', fontweight='bold') - plt.legend() - - # R² difference subplot - plt.subplot(1, 2, 2) - r2_diff = np.array(train_r2_scores) - np.array(valid_r2_scores) - plt.plot(epochs, r2_diff, 'purple', 
linewidth=2) - plt.title('Training - Validation R²') - plt.xlabel('Epoch') - plt.ylabel('R² Difference') - plt.grid(True, alpha=0.3) - plt.axhline(y=0, color='black', linestyle='-', alpha=0.3) - - # Add interpretation guidelines - plt.text(0.02, 0.98, 'Positive: Overfitting\nNegative: Underfitting', - transform=plt.gca().transAxes, verticalalignment='top', - bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) - - plt.tight_layout() - plt.savefig(os.path.join(save_dir, 'r2_curves.png'), dpi=300, bbox_inches='tight') - plt.close() - - print(f"R² curves saved to {save_dir}/r2_curves.png") - - # Print summary statistics - print(f"Best validation R²: {best_r2:.4f} at epoch {best_epoch}") - print(f"Final validation R²: {valid_r2_scores[-1]:.4f}") - print(f"Max training R²: {max(train_r2_scores):.4f}") - print(f"Max validation R²: {max(valid_r2_scores):.4f}") - -def plot_predictions(true_values, predictions, save_dir, title="Prediction Results"): - """Plot prediction vs true values scatter plot""" - # Create figure with subplots - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12)) - - # Convert to numpy arrays - true_values = np.array(true_values) - predictions = np.array(predictions) - - # Calculate metrics - mse = mean_squared_error(true_values, predictions) - mae = mean_absolute_error(true_values, predictions) - rmse = np.sqrt(mse) - r2 = r2_score(true_values, predictions) - pearson_r, pearson_p = pearsonr(true_values, predictions) - - # 1. 
Scatter plot - ax1.scatter(true_values, predictions, alpha=0.6, s=30) - - # Perfect prediction line - min_val = min(true_values.min(), predictions.min()) - max_val = max(true_values.max(), predictions.max()) - ax1.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect Prediction') - - ax1.set_xlabel('True Affinity', fontsize=12) - ax1.set_ylabel('Predicted Affinity', fontsize=12) - ax1.set_title(f'Predictions vs True Values\nR² = {r2:.4f}, RMSE = {rmse:.4f}', fontsize=14) - ax1.legend() - ax1.grid(True, alpha=0.3) - - # 2. Residual plot - residuals = predictions - true_values - ax2.scatter(predictions, residuals, alpha=0.6, s=30) - ax2.axhline(y=0, color='r', linestyle='--', lw=2) - ax2.set_xlabel('Predicted Affinity', fontsize=12) - ax2.set_ylabel('Residuals', fontsize=12) - ax2.set_title(f'Residual Plot\nMAE = {mae:.4f}', fontsize=14) - ax2.grid(True, alpha=0.3) - - # 3. Distribution of residuals - ax3.hist(residuals, bins=50, alpha=0.7, edgecolor='black') - ax3.axvline(residuals.mean(), color='r', linestyle='--', lw=2, - label=f'Mean: {residuals.mean():.4f}') - ax3.axvline(0, color='g', linestyle='-', lw=2, label='Zero') - ax3.set_xlabel('Residuals', fontsize=12) - ax3.set_ylabel('Frequency', fontsize=12) - ax3.set_title('Distribution of Residuals', fontsize=14) - ax3.legend() - ax3.grid(True, alpha=0.3) - - # 4. 
Error distribution - abs_errors = np.abs(residuals) - ax4.hist(abs_errors, bins=50, alpha=0.7, edgecolor='black', color='orange') - ax4.axvline(abs_errors.mean(), color='r', linestyle='--', lw=2, - label=f'Mean AE: {abs_errors.mean():.4f}') - ax4.axvline(np.median(abs_errors), color='g', linestyle='--', lw=2, - label=f'Median AE: {np.median(abs_errors):.4f}') - ax4.set_xlabel('Absolute Error', fontsize=12) - ax4.set_ylabel('Frequency', fontsize=12) - ax4.set_title('Distribution of Absolute Errors', fontsize=14) - ax4.legend() - ax4.grid(True, alpha=0.3) - - # Add overall metrics text - metrics_text = f""" - Metrics Summary: - ├─ RMSE: {rmse:.4f} - ├─ MAE: {mae:.4f} - ├─ R²: {r2:.4f} - ├─ Pearson R: {pearson_r:.4f} - └─ Pearson P: {pearson_p:.4e} - """ - - fig.text(0.02, 0.02, metrics_text, fontsize=10, fontfamily='monospace', - bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8)) - - plt.suptitle(title, fontsize=16, fontweight='bold') - plt.tight_layout() - plt.subplots_adjust(top=0.93, bottom=0.15) - plt.savefig(os.path.join(save_dir, 'prediction_results.png'), dpi=300, bbox_inches='tight') - plt.close() - -def plot_error_analysis(true_values, predictions, save_dir): - """Plot detailed error analysis""" - true_values = np.array(true_values) - predictions = np.array(predictions) - residuals = predictions - true_values - abs_errors = np.abs(residuals) - - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10)) - - # 1. Error vs True Values - ax1.scatter(true_values, abs_errors, alpha=0.6, s=30) - ax1.set_xlabel('True Affinity', fontsize=12) - ax1.set_ylabel('Absolute Error', fontsize=12) - ax1.set_title('Absolute Error vs True Values', fontsize=14) - ax1.grid(True, alpha=0.3) - - # Add trend line - z = np.polyfit(true_values, abs_errors, 1) - p = np.poly1d(z) - ax1.plot(sorted(true_values), p(sorted(true_values)), "r--", alpha=0.8) - - # 2. 
Error vs Predicted Values - ax2.scatter(predictions, abs_errors, alpha=0.6, s=30, color='orange') - ax2.set_xlabel('Predicted Affinity', fontsize=12) - ax2.set_ylabel('Absolute Error', fontsize=12) - ax2.set_title('Absolute Error vs Predictions', fontsize=14) - ax2.grid(True, alpha=0.3) - - # 3. Q-Q Plot for residuals - from scipy import stats - stats.probplot(residuals, dist="norm", plot=ax3) - ax3.set_title('Q-Q Plot of Residuals', fontsize=14) - ax3.grid(True, alpha=0.3) - - # 4. Cumulative error distribution - sorted_errors = np.sort(abs_errors) - cumulative = np.arange(1, len(sorted_errors) + 1) / len(sorted_errors) - ax4.plot(sorted_errors, cumulative, linewidth=2) - ax4.set_xlabel('Absolute Error', fontsize=12) - ax4.set_ylabel('Cumulative Probability', fontsize=12) - ax4.set_title('Cumulative Error Distribution', fontsize=14) - ax4.grid(True, alpha=0.3) - - # Add percentile lines - percentiles = [0.5, 0.8, 0.9, 0.95] - for p in percentiles: - error_at_p = np.percentile(abs_errors, p * 100) - ax4.axvline(error_at_p, color='red', linestyle='--', alpha=0.7) - ax4.text(error_at_p, p, f'{p*100:.0f}%', rotation=90, - verticalalignment='bottom', fontsize=10) - - plt.suptitle('Error Analysis', fontsize=16, fontweight='bold') - plt.tight_layout() - plt.savefig(os.path.join(save_dir, 'error_analysis.png'), dpi=300, bbox_inches='tight') - plt.close() - -def plot_affinity_distribution(true_values, predictions, save_dir): - """Plot distribution comparison of true vs predicted affinities""" - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5)) - - # 1. True values distribution - ax1.hist(true_values, bins=50, alpha=0.7, label='True Affinity', color='blue', edgecolor='black') - ax1.set_xlabel('Affinity', fontsize=12) - ax1.set_ylabel('Frequency', fontsize=12) - ax1.set_title('True Affinity Distribution', fontsize=14) - ax1.legend() - ax1.grid(True, alpha=0.3) - - # 2. 
Predicted values distribution - ax2.hist(predictions, bins=50, alpha=0.7, label='Predicted Affinity', color='red', edgecolor='black') - ax2.set_xlabel('Affinity', fontsize=12) - ax2.set_ylabel('Frequency', fontsize=12) - ax2.set_title('Predicted Affinity Distribution', fontsize=14) - ax2.legend() - ax2.grid(True, alpha=0.3) - - # 3. Overlapped distributions - ax3.hist(true_values, bins=50, alpha=0.5, label='True Affinity', color='blue', density=True) - ax3.hist(predictions, bins=50, alpha=0.5, label='Predicted Affinity', color='red', density=True) - ax3.set_xlabel('Affinity', fontsize=12) - ax3.set_ylabel('Density', fontsize=12) - ax3.set_title('Distribution Comparison', fontsize=14) - ax3.legend() - ax3.grid(True, alpha=0.3) - - plt.tight_layout() - plt.savefig(os.path.join(save_dir, 'affinity_distributions.png'), dpi=300, bbox_inches='tight') - plt.close() - -def create_summary_report(true_values, predictions, save_dir, model_info=None): - """Create a comprehensive summary report""" - true_values = np.array(true_values) - predictions = np.array(predictions) - - # Calculate all metrics - mse = mean_squared_error(true_values, predictions) - mae = mean_absolute_error(true_values, predictions) - rmse = np.sqrt(mse) - r2 = r2_score(true_values, predictions) - pearson_r, pearson_p = pearsonr(true_values, predictions) - - residuals = predictions - true_values - abs_errors = np.abs(residuals) - - # Create summary text - report = f""" -# Model Performance Report - -## Model Information -{f"- Model: {model_info.get('model_name', 'EGNN Affinity Model')}" if model_info else "- Model: EGNN Affinity Model"} -{f"- Parameters: {model_info.get('total_parameters', 'N/A')}" if model_info else ""} -{f"- Training Time: {model_info.get('training_time', 'N/A')}" if model_info else ""} - -## Dataset Statistics -- Total Samples: {len(true_values):,} -- True Affinity Range: [{true_values.min():.4f}, {true_values.max():.4f}] -- True Affinity Mean ± Std: {true_values.mean():.4f} ± 
{true_values.std():.4f} -- Predicted Affinity Range: [{predictions.min():.4f}, {predictions.max():.4f}] -- Predicted Affinity Mean ± Std: {predictions.mean():.4f} ± {predictions.std():.4f} - -## Performance Metrics -### Primary Metrics -- Root Mean Square Error (RMSE): {rmse:.4f} -- Mean Absolute Error (MAE): {mae:.4f} -- R-squared Score (R²): {r2:.4f} -- Pearson Correlation: {pearson_r:.4f} (p-value: {pearson_p:.2e}) - -### Error Statistics -- Mean Residual: {residuals.mean():.4f} -- Std Residual: {residuals.std():.4f} -- Mean Absolute Error: {abs_errors.mean():.4f} -- Median Absolute Error: {np.median(abs_errors):.4f} -- 90th Percentile Error: {np.percentile(abs_errors, 90):.4f} -- 95th Percentile Error: {np.percentile(abs_errors, 95):.4f} -- Max Absolute Error: {abs_errors.max():.4f} - -### Error Distribution -- Errors < 0.5: {(abs_errors < 0.5).sum():,} ({(abs_errors < 0.5).mean()*100:.1f}%) -- Errors < 1.0: {(abs_errors < 1.0).sum():,} ({(abs_errors < 1.0).mean()*100:.1f}%) -- Errors < 2.0: {(abs_errors < 2.0).sum():,} ({(abs_errors < 2.0).mean()*100:.1f}%) -- Errors ≥ 2.0: {(abs_errors >= 2.0).sum():,} ({(abs_errors >= 2.0).mean()*100:.1f}%) - -## Summary -The model achieves {'excellent' if r2 > 0.8 else 'good' if r2 > 0.6 else 'moderate' if r2 > 0.4 else 'poor'} -performance with an R² score of {r2:.4f} and RMSE of {rmse:.4f}. 
-""" - - # Save report - with open(os.path.join(save_dir, 'performance_report.md'), 'w') as f: - f.write(report) - - # Also create a CSV with detailed results - results_df = pd.DataFrame({ - 'True_Affinity': true_values, - 'Predicted_Affinity': predictions, - 'Residual': residuals, - 'Absolute_Error': abs_errors - }) - - results_df.to_csv(os.path.join(save_dir, 'detailed_results.csv'), index=False) - - print("Summary report and detailed results saved.") - -def plot_all_results(true_values, predictions, save_dir, model_info=None): - """Generate all plots and reports""" - print("Generating comprehensive result analysis...") - - # Generate all plots - plot_predictions(true_values, predictions, save_dir) - plot_error_analysis(true_values, predictions, save_dir) - plot_affinity_distribution(true_values, predictions, save_dir) - - # Create summary report - create_summary_report(true_values, predictions, save_dir, model_info) - - print(f"All results saved to: {save_dir}") \ No newline at end of file diff --git a/src/utils/sampling_torsion.py b/src/utils/sampling_torsion.py new file mode 100644 index 0000000..2658f61 --- /dev/null +++ b/src/utils/sampling_torsion.py @@ -0,0 +1,111 @@ +""" +ODE sampling for SE(3) + Torsion decomposition. + +Applies Torsion → Translation → Rotation at each Euler step. +""" + +import torch +from torch.nn.parallel import DistributedDataParallel as DDP + +from .losses_torsion import _rodrigues_rotate, _apply_rotation + + +@torch.no_grad() +def sample_trajectory_torsion( + model: torch.nn.Module, + protein_batch, + ligand_batch, + x0: torch.Tensor, + timesteps: torch.Tensor, + rotatable_edges: torch.Tensor, + mask_rotate: torch.Tensor, + return_trajectory: bool = False, +) -> dict: + """ + Sample trajectory using SE(3) + Torsion decomposition. 
+ + Args: + model: Flow matching model (ProteinLigandFlowMatchingTorsion) + protein_batch: Protein PyG batch + ligand_batch: Ligand PyG batch + x0: Initial docked coordinates [N_atoms, 3] + timesteps: Integration timesteps [num_steps + 1] + rotatable_edges: [M, 2] rotatable bond atom indices + mask_rotate: [M, N] boolean mask for torsion + return_trajectory: Whether to return full trajectory + + Returns: + Dict with 'final_coords', optionally 'trajectory' + """ + device = x0.device + batch_size = ligand_batch.batch.max().item() + 1 + num_steps = len(timesteps) - 1 + + current_coords = x0.clone() + trajectory = [current_coords.clone()] if return_trajectory else [] + + for step in range(num_steps): + t_current = timesteps[step] + t_next = timesteps[step + 1] + dt = (t_next - t_current).item() + + t = torch.ones(batch_size, device=device) * t_current + + ligand_batch_t = ligand_batch.clone() + ligand_batch_t.pos = current_coords.clone() + + # Unwrap DDP if needed + model_fn = model.module if isinstance(model, DDP) else model + output = model_fn(protein_batch, ligand_batch_t, t, rotatable_edges=rotatable_edges) + + # Apply per molecule: Torsion → Translation → Rotation + for b in range(batch_size): + mol_mask = (ligand_batch.batch == b) + mol_coords = current_coords[mol_mask].clone() + mol_indices = torch.where(mol_mask)[0] + n_atoms = mol_indices.shape[0] + offset = mol_indices[0].item() + + # 1. 
Torsion + if output['torsion'].numel() > 0 and mask_rotate.shape[0] > 0: + for m in range(mask_rotate.shape[0]): + angle = dt * output['torsion'][m].item() + if abs(angle) < 1e-6: + continue + + mask_local = mask_rotate[m, offset:offset + n_atoms] + if not mask_local.any(): + continue + + src, dst = rotatable_edges[m] + src_l = src.item() - offset + dst_l = dst.item() - offset + if not (0 <= src_l < n_atoms and 0 <= dst_l < n_atoms): + continue + + axis = mol_coords[dst_l] - mol_coords[src_l] + axis_norm = axis.norm() + if axis_norm < 1e-6: + continue + axis = axis / axis_norm + + mol_coords = _rodrigues_rotate( + mol_coords, mask_local, axis, mol_coords[dst_l], + torch.tensor(angle, device=device) + ) + + # 2. Translation + mol_coords = mol_coords + dt * output['translation'][b] + + # 3. Rotation + mol_coords = _apply_rotation(mol_coords, dt * output['rotation'][b]) + + current_coords[mol_mask] = mol_coords + + if return_trajectory: + trajectory.append(current_coords.clone()) + + result = {"final_coords": current_coords} + if return_trajectory: + result["trajectory"] = trajectory + return result diff --git a/src/utils/train.py b/src/utils/train.py deleted file mode 100644 index c106f05..0000000 --- a/src/utils/train.py +++ /dev/null @@ -1,205 +0,0 @@ -import dgl, torch, scipy -import numpy as np - -import torch.nn as nn -import torch.nn.functional as F - -from tqdm import tqdm -from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, matthews_corrcoef -from scipy.stats import spearmanr - - -def custom_loss(pred, true, direction='over', penalty_factor=2.0): - base_loss = (pred - true) ** 2 - if direction == 'over': - penalty = torch.where(pred > true, penalty_factor * base_loss, base_loss) - elif direction == 'under': - penalty = torch.where(pred < true, penalty_factor * base_loss, base_loss) - else: - penalty = base_loss - - return torch.mean(penalty) - - - -def run_train_epoch(model, loader, optimizer, 
scheduler, device='cpu'): - """Run a single training epoch. - - Args: - model: The model to train - loader: DataLoader providing batches - optimizer: Optimizer for updating weights - scheduler: Learning rate scheduler - device: Device to use for computation - - Returns: - Dictionary containing average losses and metrics - """ - model.train() - - total_loss = 0 - num_batches = 0 - - for batch in tqdm(loader, desc="Training", total=len(loader)): - prot_data, ligand_graph, ligand_mask, interaction_mask, reg_true, bind_true, nonbind_true = batch - - # Move data to device - for key in prot_data: - prot_data[key] = prot_data[key].to(device) - - ligand_graph = ligand_graph.to(device) - ligand_mask = ligand_mask.to(device) - interaction_mask = interaction_mask.to(device) - - reg_true = reg_true.to(device) - - # Forward pass - regression only - reg_pred = model(prot_data, ligand_graph, interaction_mask) - - # Calculate loss using custom loss function -#loss = custom_loss(reg_pred.squeeze(), reg_true, direction='over', penalty_factor=2.0) - # Huber - loss = F.huber_loss(reg_pred.squeeze(), reg_true) - - optimizer.zero_grad(set_to_none=True) - loss.backward() - - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - - optimizer.step() - - total_loss += loss.item() - num_batches += 1 - - torch.cuda.empty_cache() - - scheduler.step() - - return { - 'total_loss': total_loss / max(1, num_batches) - } - -@torch.no_grad() -def run_eval_epoch(model, loader, device='cpu'): - """Run a single evaluation epoch. 
- - Args: - model: The model to evaluate - loader: DataLoader providing batches - device: Device to use for computation - - Returns: - Dictionary containing losses and metrics for regression task - """ - model.eval() - - all_reg_true = [] - all_reg_pred = [] - - total_loss = 0 - num_batches = 0 - - for batch in tqdm(loader, desc="Evaluation", total=len(loader)): - prot_data, ligand_graph, ligand_mask, interaction_mask, reg_true, bind_true, nonbind_true = batch - - # Move data to device - for key in prot_data: - prot_data[key] = prot_data[key].to(device) - - ligand_graph = ligand_graph.to(device) - ligand_mask = ligand_mask.to(device) - interaction_mask = interaction_mask.to(device) - - reg_true = reg_true.to(device) - - # Forward pass - regression only - reg_pred = model(prot_data, ligand_graph, interaction_mask) - - # Calculate loss using custom loss function -#loss = custom_loss(reg_pred.squeeze(), reg_true, direction='over', penalty_factor=2.0) - loss = F.huber_loss(reg_pred.squeeze(), reg_true) - - # Accumulate losses - total_loss += loss.item() - num_batches += 1 - - # Collect predictions and true values - all_reg_true.append(reg_true) - all_reg_pred.append(reg_pred) - - torch.cuda.empty_cache() - - # Concatenate all predictions and true values - reg_true = torch.cat(all_reg_true, dim=0) - reg_pred = torch.cat(all_reg_pred, dim=0) - - # Calculate average loss - avg_loss = total_loss / max(1, num_batches) - - # Compute metrics for regression task - reg_metrics = compute_regression_metrics(reg_true, reg_pred) - - return { - 'losses': { - 'total_loss': avg_loss - }, - 'metrics': { - 'regression': reg_metrics - } - } - - -@torch.no_grad() -def compute_regression_metrics(true, pred): - """Compute comprehensive regression metrics""" - true = true.cpu().numpy() - pred = pred.cpu().numpy().squeeze() - - # Basic error metrics - mse = np.mean((true - pred) ** 2) - rmse = np.sqrt(mse) - mae = np.mean(np.abs(true - pred)) - - # Correlation metrics - pearson = 
np.corrcoef(true, pred)[0, 1] if len(true) > 1 else 0 - spearman = spearmanr(true, pred)[0] if len(true) > 1 else 0 - - # Explained variance metrics - ss_tot = np.sum((true - true.mean()) ** 2) - ss_res = np.sum((true - pred) ** 2) - r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0 - - residuals = true - pred - mean_bias = np.mean(residuals) # Bias (systematic error) - - return { - 'MSE': mse, - 'RMSE': rmse, - 'MAE': mae, - 'Pearson': pearson, - 'Spearman': spearman, - 'R2': r2, - 'Mean_Bias': mean_bias - } -@torch.no_grad() -def compute_binary_metrics(true, pred): - """Compute essential binary classification metrics""" - true = true.cpu().numpy() - pred_probs = torch.sigmoid(pred).cpu().numpy().squeeze() - pred_class = (pred_probs >= 0.5).astype(int) - - tn, fp, fn, tp = confusion_matrix(true, pred_class).ravel() - - return { - 'accuracy': accuracy_score(true, pred_class), - 'precision': precision_score(true, pred_class, zero_division=0), - 'recall': recall_score(true, pred_class, zero_division=0), - 'f1': f1_score(true, pred_class, zero_division=0), - 'auc_roc': roc_auc_score(true, pred_probs) if len(np.unique(true)) > 1 else 0, - 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0, - 'balanced_accuracy': (tp / (tp + fn) + tn / (tn + fp)) / 2 if (tp + fn) > 0 and (tn + fp) > 0 else 0, - 'mcc': matthews_corrcoef(true, pred_class), - 'f1_macro': f1_score(true, pred_class, average='macro', zero_division=0), - 'npv': tn / (tn + fn) if (tn + fn) > 0 else 0 - } - diff --git a/tests/test_joint_model.py b/tests/test_joint_model.py deleted file mode 100644 index 126afd2..0000000 --- a/tests/test_joint_model.py +++ /dev/null @@ -1,335 +0,0 @@ -""" -Test script for the Joint Graph Architecture. -Verifies forward pass shapes and basic functionality. 
-""" - -import sys -import os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) - -import torch -from torch_geometric.data import Data, Batch - - -def create_dummy_protein_batch(batch_size=2, device='cpu'): - """Create dummy protein batch for testing.""" - graphs = [] - for i in range(batch_size): - n_residues = 30 + i * 10 # Variable sizes - n_edges = n_residues * 8 # ~8 edges per node - - graph = Data( - x=torch.randn(n_residues, 76, device=device), - pos=torch.randn(n_residues, 3, device=device) * 10, - edge_index=torch.randint(0, n_residues, (2, n_edges), device=device), - edge_attr=torch.randn(n_edges, 39, device=device), - node_vector_features=torch.randn(n_residues, 31, 3, device=device), - edge_vector_features=torch.randn(n_edges, 8, 3, device=device), - ) - graphs.append(graph) - - return Batch.from_data_list(graphs) - - -def create_dummy_ligand_batch(batch_size=2, device='cpu'): - """Create dummy ligand batch for testing.""" - graphs = [] - for i in range(batch_size): - n_atoms = 20 + i * 5 # Variable sizes - n_edges = n_atoms * 4 # ~4 edges per atom - - graph = Data( - x=torch.randn(n_atoms, 122, device=device), - pos=torch.randn(n_atoms, 3, device=device) * 5, - edge_index=torch.randint(0, n_atoms, (2, n_edges), device=device), - edge_attr=torch.randn(n_edges, 44, device=device), - ) - graphs.append(graph) - - return Batch.from_data_list(graphs) - - -def test_build_cross_edges(): - """Test cross-edge construction.""" - from src.models.network import build_cross_edges - - protein_pos = torch.randn(30, 3) * 10 - ligand_pos = torch.randn(20, 3) * 5 - - edge_index = build_cross_edges( - protein_pos, ligand_pos, - distance_cutoff=10.0, max_neighbors=8 - ) - - print(f"Cross-edge construction:") - print(f" Protein nodes: 30, Ligand nodes: 20") - print(f" Cross edges: {edge_index.shape[1]} (bidirectional)") - print(f" Edge index shape: {edge_index.shape}") - - # Verify bidirectional: first half P→L, second half L→P - n_half = 
edge_index.shape[1] // 2 - print(f" P→L edges: {n_half}, L→P edges: {edge_index.shape[1] - n_half}") - - assert edge_index.shape[0] == 2 - assert edge_index.shape[1] > 0 - print(" PASSED") - - -def test_build_intra_edges(): - """Test intra-edge construction.""" - from src.models.network import build_intra_edges - - n_nodes = 30 - pos = torch.randn(n_nodes, 3) * 10 - - # Create some existing edges (simulating pre-computed backbone edges) - existing_src = torch.arange(0, n_nodes - 1) - existing_dst = torch.arange(1, n_nodes) - existing_edge_index = torch.stack([ - torch.cat([existing_src, existing_dst]), - torch.cat([existing_dst, existing_src]) - ]) # Bidirectional sequential edges - - edge_index = build_intra_edges( - pos, existing_edge_index, - distance_cutoff=15.0, max_neighbors=8 - ) - - print(f"\nIntra-edge construction:") - print(f" Nodes: {n_nodes}") - print(f" Existing edges: {existing_edge_index.shape[1]}") - print(f" New intra edges: {edge_index.shape[1]}") - print(f" Edge index shape: {edge_index.shape}") - - assert edge_index.shape[0] == 2 - # Verify no overlap with existing edges - if edge_index.shape[1] > 0: - existing_keys = set((existing_edge_index[0].long() * n_nodes + existing_edge_index[1].long()).tolist()) - new_keys = (edge_index[0].long() * n_nodes + edge_index[1].long()).tolist() - overlap = sum(1 for k in new_keys if k in existing_keys) - print(f" Overlap with existing: {overlap} (should be 0)") - assert overlap == 0 - print(" PASSED") - - -def test_joint_network(): - """Test JointProteinLigandNetwork forward pass.""" - from src.models.network import JointProteinLigandNetwork - - device = 'cpu' - batch_size = 2 - - network = JointProteinLigandNetwork( - protein_input_scalar_dim=76, - protein_input_vector_dim=31, - protein_input_edge_scalar_dim=39, - protein_input_edge_vector_dim=8, - ligand_input_scalar_dim=122, - ligand_input_edge_scalar_dim=44, - hidden_scalar_dim=32, # Small for testing - hidden_vector_dim=8, - hidden_edge_dim=32, - 
cross_edge_distance_cutoff=15.0, # Larger cutoff for random data - cross_edge_max_neighbors=8, - cross_edge_num_rbf=16, - intra_edge_distance_cutoff=15.0, - intra_edge_max_neighbors=8, - num_layers=2, # Few layers for speed - dropout=0.0, - condition_dim=64, - ).to(device) - - protein_batch = create_dummy_protein_batch(batch_size, device) - ligand_batch = create_dummy_ligand_batch(batch_size, device) - - # Create dummy time condition [B, condition_dim] - time_condition = torch.randn(batch_size, 64, device=device) - - print(f"\nJointProteinLigandNetwork forward pass:") - print(f" Protein nodes: {protein_batch.num_nodes}") - print(f" Ligand nodes: {ligand_batch.num_nodes}") - - velocity = network(protein_batch, ligand_batch, time_condition=time_condition) - - print(f" velocity shape: {velocity.shape} (expected [{ligand_batch.num_nodes}, 3])") - - assert velocity.shape == (ligand_batch.num_nodes, 3) - print(" PASSED") - - -def test_full_model(): - """Test ProteinLigandFlowMatchingJoint end-to-end.""" - from src.models.flowmatching import ProteinLigandFlowMatchingJoint - - device = 'cpu' - batch_size = 2 - - model = ProteinLigandFlowMatchingJoint( - protein_input_scalar_dim=76, - protein_input_vector_dim=31, - protein_input_edge_scalar_dim=39, - protein_input_edge_vector_dim=8, - ligand_input_scalar_dim=122, - ligand_input_edge_scalar_dim=44, - hidden_scalar_dim=32, - hidden_vector_dim=8, - hidden_edge_dim=32, - cross_edge_distance_cutoff=15.0, - cross_edge_max_neighbors=8, - cross_edge_num_rbf=16, - intra_edge_distance_cutoff=15.0, - intra_edge_max_neighbors=8, - joint_num_layers=4, - hidden_dim=64, - dropout=0.0, - use_esm_embeddings=False, - ).to(device) - - protein_batch = create_dummy_protein_batch(batch_size, device) - ligand_batch = create_dummy_ligand_batch(batch_size, device) - t = torch.rand(batch_size, device=device) - - print(f"\nProteinLigandFlowMatchingJoint full forward pass:") - print(f" Protein nodes: {protein_batch.num_nodes}") - print(f" Ligand 
nodes: {ligand_batch.num_nodes}") - print(f" Batch size: {batch_size}") - - velocity = model(protein_batch, ligand_batch, t) - - print(f" Velocity shape: {velocity.shape} (expected [{ligand_batch.num_nodes}, 3])") - assert velocity.shape == (ligand_batch.num_nodes, 3) - - # Test backward pass - loss = velocity.pow(2).mean() - loss.backward() - print(f" Loss: {loss.item():.6f}") - print(f" Backward pass: OK") - - # Count parameters - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f" Trainable parameters: {total_params:,}") - print(" PASSED") - - -def test_model_builder(): - """Test model builder with joint architecture config.""" - from src.utils.model_builder import build_model - - config = { - 'architecture': 'joint', - 'protein_input_scalar_dim': 76, - 'protein_input_vector_dim': 31, - 'protein_input_edge_scalar_dim': 39, - 'protein_input_edge_vector_dim': 8, - 'ligand_input_scalar_dim': 122, - 'ligand_input_edge_scalar_dim': 44, - 'hidden_scalar_dim': 32, - 'hidden_vector_dim': 8, - 'hidden_edge_dim': 32, - 'cross_edge_distance_cutoff': 15.0, - 'cross_edge_max_neighbors': 8, - 'cross_edge_num_rbf': 16, - 'intra_edge_distance_cutoff': 15.0, - 'intra_edge_max_neighbors': 8, - 'joint_num_layers': 4, - 'hidden_dim': 64, - 'dropout': 0.0, - 'use_esm_embeddings': False, - } - - device = 'cpu' - model = build_model(config, device) - - print(f"\nModel builder test:") - print(f" Model type: {type(model).__name__}") - assert type(model).__name__ == 'ProteinLigandFlowMatchingJoint' - print(" PASSED") - - -def test_esm_integration(): - """Test ESM gated concatenation integration.""" - from src.models.flowmatching import ProteinLigandFlowMatchingJoint - - device = 'cpu' - batch_size = 2 - - model = ProteinLigandFlowMatchingJoint( - protein_input_scalar_dim=76, - protein_input_vector_dim=31, - protein_input_edge_scalar_dim=39, - protein_input_edge_vector_dim=8, - ligand_input_scalar_dim=122, - ligand_input_edge_scalar_dim=44, - 
hidden_scalar_dim=32, - hidden_vector_dim=8, - hidden_edge_dim=32, - cross_edge_distance_cutoff=15.0, - cross_edge_max_neighbors=8, - cross_edge_num_rbf=16, - intra_edge_distance_cutoff=15.0, - intra_edge_max_neighbors=8, - joint_num_layers=4, - hidden_dim=64, - dropout=0.0, - use_esm_embeddings=True, - esmc_dim=1152, - esm3_dim=1536, - esm_proj_dim=64, # Small for testing - ).to(device) - - protein_batch = create_dummy_protein_batch(batch_size, device) - ligand_batch = create_dummy_ligand_batch(batch_size, device) - t = torch.rand(batch_size, device=device) - - # Add dummy ESM embeddings - n_protein = protein_batch.num_nodes - protein_batch.esmc_embeddings = torch.randn(n_protein, 1152, device=device) - protein_batch.esm3_embeddings = torch.randn(n_protein, 1536, device=device) - - print(f"\nESM gated concatenation test:") - print(f" Protein nodes: {n_protein}") - print(f" ESM proj dim: 64") - print(f" Effective protein scalar dim: 76 + 64 = 140") - - velocity = model(protein_batch, ligand_batch, t) - - print(f" Velocity shape: {velocity.shape} (expected [{ligand_batch.num_nodes}, 3])") - assert velocity.shape == (ligand_batch.num_nodes, 3) - - # Test backward - loss = velocity.pow(2).mean() - loss.backward() - print(f" Backward pass: OK") - - # Verify ESM gate is in computation graph (gradients assigned) - gate_has_grad = all(p.grad is not None for p in model.esm_gate.parameters()) - print(f" ESM gate gradients: {'OK' if gate_has_grad else 'MISSING'}") - assert gate_has_grad - - # Test without ESM embeddings (should zero-pad) - model.zero_grad() - protein_batch2 = create_dummy_protein_batch(batch_size, device) - velocity2 = model(protein_batch2, ligand_batch, t) - assert velocity2.shape == (ligand_batch.num_nodes, 3) - print(f" Without ESM (zero-pad): OK") - - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f" Trainable parameters: {total_params:,}") - print(" PASSED") - - -if __name__ == "__main__": - print("=" * 60) - 
print("Testing Joint Graph Architecture") - print("=" * 60) - - test_build_cross_edges() - test_build_intra_edges() - test_joint_network() - test_full_model() - test_esm_integration() - test_model_builder() - - print("\n" + "=" * 60) - print("All tests PASSED!") - print("=" * 60) diff --git a/train_torsion.py b/train_torsion.py new file mode 100644 index 0000000..42427af --- /dev/null +++ b/train_torsion.py @@ -0,0 +1,465 @@ +#!/usr/bin/env python +""" +Training script for SE(3) + Torsion Decomposition Flow Matching. + +Predicts translation [3] + rotation [3] + torsion [M] instead of per-atom velocity [N, 3]. +""" + +import torch +import torch.nn as nn +import numpy as np +from tqdm import tqdm +import yaml +import argparse +from datetime import datetime +from torch.utils.data import DataLoader +import wandb + +from src.data.dataset_torsion import FlowFixTorsionDataset, collate_torsion_batch +from src.models.flowmatching_torsion import ProteinLigandFlowMatchingTorsion +from src.utils.losses_torsion import compute_se3_torsion_loss +from src.utils.sampling_torsion import sample_trajectory_torsion +from src.utils.sampling import generate_timestep_schedule +from src.utils.training_utils import build_optimizer_and_scheduler +from src.utils.early_stop import EarlyStopping +from src.utils.utils import set_random_seed +from src.utils.experiment import ExperimentManager +from src.utils.wandb_logger import ( + WandBLogger, + extract_module_gradient_norms, + extract_parameter_stats, +) + + +def build_torsion_model(model_config, device): + """Build ProteinLigandFlowMatchingTorsion from config.""" + model = ProteinLigandFlowMatchingTorsion( + protein_input_scalar_dim=model_config.get('protein_input_scalar_dim', 76), + protein_input_vector_dim=model_config.get('protein_input_vector_dim', 31), + protein_input_edge_scalar_dim=model_config.get('protein_input_edge_scalar_dim', 39), + protein_input_edge_vector_dim=model_config.get('protein_input_edge_vector_dim', 8), + 
protein_hidden_scalar_dim=model_config.get('protein_hidden_scalar_dim', 128), + protein_hidden_vector_dim=model_config.get('protein_hidden_vector_dim', 32), + protein_output_scalar_dim=model_config.get('protein_output_scalar_dim', 128), + protein_output_vector_dim=model_config.get('protein_output_vector_dim', 32), + protein_num_layers=model_config.get('protein_num_layers', 3), + ligand_input_scalar_dim=model_config.get('ligand_input_scalar_dim', 121), + ligand_input_edge_scalar_dim=model_config.get('ligand_input_edge_scalar_dim', 44), + ligand_hidden_scalar_dim=model_config.get('ligand_hidden_scalar_dim', 128), + ligand_hidden_vector_dim=model_config.get('ligand_hidden_vector_dim', 16), + ligand_output_scalar_dim=model_config.get('ligand_output_scalar_dim', 128), + ligand_output_vector_dim=model_config.get('ligand_output_vector_dim', 16), + ligand_num_layers=model_config.get('ligand_num_layers', 3), + interaction_num_heads=model_config.get('interaction_num_heads', 8), + interaction_num_layers=model_config.get('interaction_num_layers', 2), + interaction_num_rbf=model_config.get('interaction_num_rbf', 32), + interaction_pair_dim=model_config.get('interaction_pair_dim', 64), + velocity_hidden_scalar_dim=model_config.get('velocity_hidden_scalar_dim', 128), + velocity_hidden_vector_dim=model_config.get('velocity_hidden_vector_dim', 16), + velocity_num_layers=model_config.get('velocity_num_layers', 4), + hidden_dim=model_config.get('hidden_dim', 256), + dropout=model_config.get('dropout', 0.1), + use_esm_embeddings=model_config.get('use_esm_embeddings', True), + esmc_dim=model_config.get('esmc_dim', 1152), + esm3_dim=model_config.get('esm3_dim', 1536), + ).to(device) + return model + + +class FlowFixTorsionTrainer: + """Trainer for SE(3) + Torsion decomposition flow matching.""" + + def __init__(self, config): + self.config = config + self.device = torch.device(config['device']) + + set_random_seed(config.get('seed', 42)) + + self.setup_experiment() + self.setup_data() + 
self.setup_model() + self.setup_optimizer() + self.setup_early_stopping() + + self.global_step = 0 + self.current_epoch = 0 + self.best_val_success = 0.0 + + self.setup_wandb() + self.wandb_logger = WandBLogger(enabled=self.wandb_enabled) + + def setup_experiment(self): + """Setup experiment manager.""" + wandb_config = self.config.get('wandb', {}) + run_name = wandb_config.get('name') + if run_name is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + run_name = f"torsion_{timestamp}" + + base_dir = self.config.get('experiment', {}).get('base_dir', 'save') + self.exp_manager = ExperimentManager(base_dir=base_dir, run_name=run_name, config=self.config) + self.checkpoint_dir = self.exp_manager.checkpoints_dir + self.exp_manager.logger.info(f"Experiment: {run_name} | Device: {self.device}") + + def setup_data(self): + """Setup torsion-aware datasets.""" + data_config = self.config['data'] + training_config = self.config['training'] + + self.train_dataset = FlowFixTorsionDataset( + data_dir=data_config.get('data_dir', 'train_data'), + split_file=data_config.get('split_file'), + split='train', + max_samples=data_config.get('max_train_samples'), + seed=self.config.get('seed', 42), + loading_mode=data_config.get('loading_mode', 'lazy'), + ) + self.val_dataset = FlowFixTorsionDataset( + data_dir=data_config.get('data_dir', 'train_data'), + split_file=data_config.get('split_file'), + split='valid', + max_samples=data_config.get('max_val_samples'), + seed=self.config.get('seed', 42), + loading_mode=data_config.get('loading_mode', 'lazy'), + ) + + self.train_loader = DataLoader( + self.train_dataset, + batch_size=training_config['batch_size'], + shuffle=True, + num_workers=data_config.get('num_workers', 4), + collate_fn=collate_torsion_batch, + ) + self.val_loader = DataLoader( + self.val_dataset, + batch_size=training_config.get('val_batch_size', 4), + shuffle=False, + num_workers=data_config.get('num_workers', 4), + collate_fn=collate_torsion_batch, + ) + + 
self.exp_manager.logger.info( + f"Train: {len(self.train_dataset)} | Val: {len(self.val_dataset)} PDBs" + ) + + def setup_model(self): + """Initialize torsion model.""" + self.model = build_torsion_model(self.config['model'], self.device) + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.exp_manager.logger.info(f"Model: {total_params:,} trainable params (SE(3)+Torsion)") + + def setup_optimizer(self): + """Setup optimizer and scheduler.""" + self.optimizer, self.scheduler = build_optimizer_and_scheduler( + self.model, self.config['training'] + ) + + def setup_early_stopping(self): + """Setup early stopping.""" + val_config = self.config['training'].get('validation', {}) + self.early_stopper = EarlyStopping( + mode='min', + patience=val_config.get('early_stopping_patience', 50), + restore_best_weights=True, + save_dir=str(self.checkpoint_dir), + ) + + def setup_wandb(self): + """Setup WandB.""" + wandb_config = self.config.get('wandb', {}) + if not wandb_config.get('enabled', False): + self.wandb_enabled = False + return + + self.wandb_enabled = True + wandb.init( + project=wandb_config.get('project', 'protein-ligand-flowfix'), + entity=wandb_config.get('entity'), + name=self.exp_manager.run_name, + tags=wandb_config.get('tags', []), + dir=self.exp_manager.get_wandb_dir(), + config={ + 'model': self.config['model'], + 'training': self.config['training'], + 'output_mode': 'torsion', + }, + ) + + def train_step(self, batch): + """Single training step for SE(3) + Torsion.""" + ligand_batch = batch['ligand_graph'].to(self.device) + protein_batch = batch['protein_graph'].to(self.device) + coords_x0 = batch['ligand_coords_x0'].to(self.device) + coords_x1 = batch['ligand_coords_x1'].to(self.device) + batch_size = len(batch['pdb_ids']) + + # Sample timestep + t = torch.rand(batch_size, device=self.device) + + # Interpolate + t_expanded = t[ligand_batch.batch].unsqueeze(-1) + x_t = (1 - t_expanded) * coords_x0 + t_expanded * coords_x1 
+ + ligand_batch_t = ligand_batch.clone() + ligand_batch_t.pos = x_t + + # Torsion data + torsion_data = batch.get('torsion_data') + if torsion_data is not None: + target = { + 'translation': torsion_data['translation'].to(self.device), + 'rotation': torsion_data['rotation'].to(self.device), + 'torsion_changes': torsion_data['torsion_changes'].to(self.device), + } + rotatable_edges = torsion_data['rotatable_edges'].to(self.device) + mask_rotate = torsion_data['mask_rotate'].to(self.device) + else: + # Fallback: rigid body only + from src.data.ligand_feat import compute_rigid_transform + translations, rotations = [], [] + for b in range(batch_size): + mol_mask = (ligand_batch.batch == b) + trans, rot = compute_rigid_transform(coords_x0[mol_mask].cpu(), coords_x1[mol_mask].cpu()) + translations.append(trans) + rotations.append(rot) + target = { + 'translation': torch.stack(translations).to(self.device), + 'rotation': torch.stack(rotations).to(self.device), + 'torsion_changes': torch.zeros(0, device=self.device), + } + rotatable_edges = torch.zeros(0, 2, dtype=torch.long, device=self.device) + mask_rotate = torch.zeros(0, coords_x0.shape[0], dtype=torch.bool, device=self.device) + + # Forward + pred = self.model(protein_batch, ligand_batch_t, t, rotatable_edges=rotatable_edges) + + # Loss + loss_config = self.config['training'].get('torsion_loss', {}) + losses = compute_se3_torsion_loss( + pred=pred, target=target, + coords_x0=coords_x0, coords_x1=coords_x1, + mask_rotate=mask_rotate, rotatable_edges=rotatable_edges, + batch_indices=ligand_batch.batch, + w_trans=loss_config.get('w_trans', 1.0), + w_rot=loss_config.get('w_rot', 1.0), + w_tor=loss_config.get('w_tor', 1.0), + w_coord=loss_config.get('w_coord', 0.5), + ) + + loss = losses['total'] + + # Backward + grad_accum = self.config['training'].get('gradient_accumulation_steps', 1) + (loss / grad_accum).backward() + + if (self.global_step + 1) % grad_accum == 0: + clip_val = 
self.config['training'].get('gradient_clip') + if clip_val: + nn.utils.clip_grad_norm_(self.model.parameters(), clip_val) + self.optimizer.step() + self.optimizer.zero_grad() + + with torch.no_grad(): + rmsd = torch.sqrt(torch.mean((x_t - coords_x1) ** 2)) + + return { + 'loss': loss.item(), + 'rmsd': rmsd.item(), + 'loss_trans': losses['translation'].item(), + 'loss_rot': losses['rotation'].item(), + 'loss_tor': losses['torsion'].item(), + 'loss_coord': losses['coord_recon'].item(), + } + + @torch.no_grad() + def validate(self): + """Validation with ODE sampling in torsion space.""" + self.model.eval() + + all_rmsds = [] + all_initial_rmsds = [] + + num_steps = self.config['sampling'].get('num_steps', 20) + schedule = self.config['sampling'].get('schedule', 'uniform') + timesteps = generate_timestep_schedule(num_steps, schedule, self.device) + + for batch in tqdm(self.val_loader, desc="Validation"): + ligand_batch = batch['ligand_graph'].to(self.device) + protein_batch = batch['protein_graph'].to(self.device) + coords_x0 = batch['ligand_coords_x0'].to(self.device) + coords_x1 = batch['ligand_coords_x1'].to(self.device) + + # Initial RMSD + init_rmsd = torch.sqrt(torch.mean((coords_x0 - coords_x1) ** 2, dim=-1)) + all_initial_rmsds.extend(init_rmsd.cpu().numpy()) + + # Torsion data + torsion_data = batch.get('torsion_data') + if torsion_data is not None: + rot_edges = torsion_data['rotatable_edges'].to(self.device) + mask_rot = torsion_data['mask_rotate'].to(self.device) + else: + rot_edges = torch.zeros(0, 2, dtype=torch.long, device=self.device) + mask_rot = torch.zeros(0, coords_x0.shape[0], dtype=torch.bool, device=self.device) + + result = sample_trajectory_torsion( + model=self.model, + protein_batch=protein_batch, + ligand_batch=ligand_batch, + x0=coords_x0, + timesteps=timesteps, + rotatable_edges=rot_edges, + mask_rotate=mask_rot, + ) + + refined = result['final_coords'] + per_sample_rmsd = torch.sqrt(torch.mean((refined - coords_x1) ** 2, dim=-1)) + 
all_rmsds.extend(per_sample_rmsd.cpu().numpy()) + + # Metrics + rmsds = np.array(all_rmsds) + avg_rmsd = rmsds.mean() + avg_init_rmsd = np.mean(all_initial_rmsds) + success_2a = (rmsds < 2.0).mean() * 100 + success_1a = (rmsds < 1.0).mean() * 100 + success_05a = (rmsds < 0.5).mean() * 100 + + print(f"\n Validation Results:") + print(f" Initial RMSD: {avg_init_rmsd:.4f} A") + print(f" Final RMSD: {avg_rmsd:.4f} A") + print(f" Success <2A: {success_2a:.1f}% <1A: {success_1a:.1f}% <0.5A: {success_05a:.1f}%") + + if self.wandb_enabled: + self.wandb_logger.log_validation_epoch( + val_loss=avg_rmsd, val_rmsd=avg_rmsd, + val_rmsd_initial=avg_init_rmsd, val_rmsd_final=avg_rmsd, + success_rate_2a=success_2a, success_rate_1a=success_1a, + success_rate_05a=success_05a, epoch=self.current_epoch, + ) + + # Early stopping + val_metrics = { + 'rmsd': avg_rmsd, 'success_2A': success_2a, + 'success_1A': success_1a, 'success_05A': success_05a, + } + early_stop = self.early_stopper.step( + score=-success_2a, model=self.model, + optimizer=self.optimizer, scheduler=self.scheduler, + epoch=self.current_epoch, valid_metrics=val_metrics, + ) + + if success_2a > self.best_val_success: + self.best_val_success = success_2a + + self.model.train() + return avg_rmsd, early_stop + + def save_checkpoint(self, filename): + """Save checkpoint.""" + torch.save({ + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'global_step': self.global_step, + 'current_epoch': self.current_epoch, + 'best_val_success': self.best_val_success, + 'config': self.config, + }, self.checkpoint_dir / filename) + + def train(self): + """Main training loop.""" + num_epochs = self.config['training']['num_epochs'] + val_freq = self.config['training'].get('validation', {}).get('frequency', 20) + + for epoch in range(num_epochs): + self.current_epoch = epoch + + if hasattr(self.train_dataset, 'set_epoch'): + self.train_dataset.set_epoch(epoch) + + self.model.train() + 
epoch_losses = [] + + pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{num_epochs}") + for batch in pbar: + metrics = self.train_step(batch) + epoch_losses.append(metrics['loss']) + self.global_step += 1 + + pbar.set_postfix({ + 'loss': f"{metrics['loss']:.4f}", + 'tr': f"{metrics['loss_trans']:.3f}", + 'rot': f"{metrics['loss_rot']:.3f}", + 'tor': f"{metrics['loss_tor']:.3f}", + }) + + if self.wandb_enabled and self.global_step % 10 == 0: + self.wandb_logger.log({ + 'train/loss': metrics['loss'], + 'train/loss_trans': metrics['loss_trans'], + 'train/loss_rot': metrics['loss_rot'], + 'train/loss_tor': metrics['loss_tor'], + 'train/loss_coord': metrics['loss_coord'], + 'train/rmsd': metrics['rmsd'], + 'train/lr': self.optimizer.param_groups[0]['lr'], + 'meta/epoch': epoch, + 'meta/step': self.global_step, + }) + + # Validation + early_stop = False + if epoch > 0 and epoch % val_freq == 0: + _, early_stop = self.validate() + + if early_stop: + print("\n Training stopped early") + break + + # Checkpoints + save_freq = self.config['checkpoint'].get('save_freq', 10) + if epoch % save_freq == 0: + self.save_checkpoint(f'epoch_{epoch:04d}.pt') + if self.config['checkpoint'].get('save_latest', True): + self.save_checkpoint('latest.pt') + + # Scheduler + if self.scheduler: + self.scheduler.step() + + # Epoch summary + avg_loss = np.mean(epoch_losses) + lr = self.optimizer.param_groups[0]['lr'] + print(f" Epoch {epoch}: loss={avg_loss:.4f} lr={lr:.6f} " + f"early_stop={self.early_stopper.counter}/{self.early_stopper.patience}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + parser.add_argument('--resume', type=str, default=None) + parser.add_argument('--device', type=str, default=None) + args = parser.parse_args() + + with open(args.config, 'r') as f: + config = yaml.safe_load(f) + + if args.device: + config['device'] = args.device + + trainer = FlowFixTorsionTrainer(config) + + if args.resume: + ckpt = 
torch.load(args.resume, weights_only=False) + trainer.model.load_state_dict(ckpt['model_state_dict']) + trainer.optimizer.load_state_dict(ckpt['optimizer_state_dict']) + trainer.global_step = ckpt['global_step'] + print(f"Resumed from step {trainer.global_step}") + + try: + trainer.train() + finally: + if trainer.wandb_enabled: + wandb.finish() + + +if __name__ == '__main__': + main()