diff --git a/README.md b/README.md
index 509cce6..7dc518c 100644
--- a/README.md
+++ b/README.md
@@ -81,14 +81,14 @@ pip install -r requirements.txt
### Training
```bash
-python train.py --config configs/train_joint.yaml
+python train.py --config configs/train.yaml
```
### Multi-GPU Training (DDP)
```bash
python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=8 \
- train.py --config configs/train_joint.yaml
+ train.py --config configs/train.yaml
```
### Inference
diff --git a/configs/README.md b/configs/README.md
index 5027e3f..7543574 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -4,11 +4,15 @@
## 파일 목록
-- `train_joint.yaml`
- - **학습 설정** (joint protein-ligand graph 아키텍처)
- - Slurm 스크립트 `scripts/slurm/run_train_joint*.sh`에서 기본으로 사용합니다.
+- `train.yaml`
+ - **Cartesian 학습 설정** (per-atom velocity field)
+ - Slurm 스크립트 `scripts/slurm/run_train_full.sh`에서 기본으로 사용합니다.
- 추론 시에도 이 설정을 기반으로 모델을 로드합니다.
+- `train_torsion.yaml`
+ - **SE(3) + Torsion 학습 설정** (translation + rotation + torsion)
+ - `train_torsion.py`에서 사용합니다.
+
## 공통 구조(개요)
설정 파일은 아래 섹션을 갖습니다.
@@ -34,21 +38,21 @@
### 학습 (로컬)
```bash
-python train.py --config configs/train_joint.yaml
+python train.py --config configs/train.yaml
```
### 멀티 GPU 학습 (DDP, 로컬/서버)
```bash
python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=8 \
- train.py --config configs/train_joint.yaml
+ train.py --config configs/train.yaml
```
### 추론/평가
```bash
python inference.py \
- --config configs/train_joint.yaml \
+ --config configs/train.yaml \
--checkpoint /path/to/checkpoint.pt \
--device cuda
```
diff --git a/configs/overfit_test.yaml b/configs/overfit_test.yaml
deleted file mode 100644
index 1f87db6..0000000
--- a/configs/overfit_test.yaml
+++ /dev/null
@@ -1,116 +0,0 @@
-# FlowFix Overfitting Test Configuration
-# Based on successful model: flowfix_20251104_151853
-# Uses ONLY simple flow matching loss (no auxiliary losses)
-
-device: cuda
-seed: 42
-
-# Data - Small dataset for overfitting test (Dynamic Pose Sampling)
-data:
- data_dir: train_data # ✅ Relative path to current directory
- split_file: null
- max_train_samples: 50 # Small dataset for overfitting test
- max_val_samples: 5 # Validation samples
- num_workers: 4
- # Note: Each epoch samples different poses per PDB for augmentation
-
-# Model dimensions (same as successful model)
-model:
- # Protein network
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
- protein_hidden_scalar_dim: 64
- protein_hidden_vector_dim: 16
- protein_output_scalar_dim: 64
- protein_output_vector_dim: 16
- protein_num_layers: 4
-
- # Ligand network
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
- ligand_hidden_scalar_dim: 64
- ligand_hidden_vector_dim: 16
- ligand_output_scalar_dim: 64
- ligand_output_vector_dim: 16
- ligand_num_layers: 4
-
- # Interaction network (CRITICAL: 4 layers, not 8!)
- interaction_hidden_dim: 256
- interaction_num_heads: 8
- interaction_num_rbf: 32 # Rich distance encoding
- interaction_pair_dim: 64 # Expressive pair bias
- interaction_num_layers: 4 # ✅ Matches successful model (13M params)
-
- # Velocity predictor
- velocity_hidden_scalar_dim: 64
- velocity_hidden_vector_dim: 16
- velocity_num_layers: 4
-
- # Regularization
- dropout: 0.1 # ✅ Enable dropout
-
-# Training - Settings from successful model
-training:
- num_epochs: 500
- batch_size: 2
- num_timesteps_per_sample: 16 # Sample 16 timesteps per system
- val_batch_size: 8
- learning_rate: 0.001 # 1e-3 for fast convergence on small data
- gradient_clip: 100.0
-
- optimizer:
- type: adam
- weight_decay: 0.0 # No regularization for overfitting
- betas: [0.9, 0.999]
- eps: 1.0e-08
-
- scheduler:
- # CosineAnnealingWarmRestarts
- eta_max: 0.01
- T_0: 10 # Restart every 10 epochs
- T_mult: 1 # Same period
- T_up: 2 # Warmup for 2 epochs
- gamma: 0.95 # Decay by 5% each restart
-
- validation:
- frequency: 150 # ✅ Validate every 10 epochs
- save_best: true
- early_stopping_patience: 100
-
-# Sampling (for validation)
-sampling:
- num_steps: 50
- method: euler
-
-# Experiment management
-experiment:
- base_dir: save
-
-# Checkpoints
-checkpoint:
- save_freq: 10
- save_latest: true
- keep_last_n: 10
-
-# Visualization
-visualization:
- enabled: true
- save_animation: true
- animation_fps: 10
- num_samples: 1
-
-# WandB logging
-wandb:
- enabled: true
- project: "protein-ligand-flowfix-overfit"
- entity: null
- name: null # Auto-generated
- tags: ["overfit-test", "simple-flow-matching", "5-samples"]
- log_gradients: true
- log_model_weights: true
- log_animations: true
- log_parameter_histograms: true
- log_parameter_evolution: true
- log_layer_analysis: true
diff --git a/configs/train_fast_overfit.yaml b/configs/train_fast_overfit.yaml
deleted file mode 100644
index e6aa93c..0000000
--- a/configs/train_fast_overfit.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# Fast Overfit Test: Sink Flow, Small Batch, High LR
-data:
- data_dir: train_data
- split_file: train_data/splits_overfit_32.json
- max_train_samples: null
- max_val_samples: 4
- fix_pose: true
- fix_pose_high_rmsd: true
- position_noise_scale: 0.05 # Reduced noise for faster overfitting
- num_workers: 4
- loading_mode: lazy
- timestep_sampling:
- method: uniform
- mu: 0.8
- sigma: 1.7
- num_timesteps_per_sample: 1
-
-model:
- architecture: joint
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
- hidden_scalar_dim: 128
- hidden_vector_dim: 32
- hidden_edge_dim: 128
- joint_num_layers: 6
- cross_edge_distance_cutoff: 6.0
- cross_edge_max_neighbors: 16
- cross_edge_num_rbf: 32
- intra_edge_distance_cutoff: 6.0
- intra_edge_max_neighbors: 16
- hidden_dim: 256
- esm_proj_dim: 128
- self_conditioning: false
- dropout: 0.0
-
-training:
- num_epochs: 300
- batch_size: 4 # 8 updates per epoch (32/4)
- val_batch_size: 4
- gradient_clip: 10.0 # Tighter clip for stability with sink flow
- gradient_accumulation_steps: 1
- distance_geometry_weight: 0.1
- ema:
- enabled: false
- compile: false
-
- optimizer:
- type: muon
- muon:
- lr: 0.05 # Higher LR for fast overfit
- momentum: 0.95
- weight_decay: 0.0
- min_lr: 0.005
- adamw:
- lr: 0.001 # Higher AdamW LR too
- weight_decay: 0.0
- betas: [0.9, 0.999]
- eps: 1.0e-08
- min_lr: 0.0001
-
- schedule:
- warmup_fraction: 0.05
- plateau_fraction: 0.80
- decay_fraction: 0.15
- warmup_epochs: 5
-
- validation:
- frequency: 20
- early_stopping_patience: 1000
-
-sampling:
- num_steps: 20
- method: euler
- schedule: uniform
- t_end: 1.0
-
-experiment:
- base_dir: save
-
-checkpoint:
- save_freq: 20
- save_latest: true
- keep_last_n: 3
-
-visualization:
- enabled: true
- save_animation: true
- animation_fps: 10
- num_samples: 1
-
-wandb:
- enabled: true
- project: "protein-ligand-flowfix"
- name: "overfit-fast-sink"
- tags: ["overfit-test", "sink-flow", "fast-overfit"]
- log_gradients: false
- log_model_weights: false
- log_animations: true
diff --git a/configs/train_joint.yaml b/configs/train_joint.yaml
deleted file mode 100644
index ad8d3df..0000000
--- a/configs/train_joint.yaml
+++ /dev/null
@@ -1,143 +0,0 @@
-# FlowFix Joint Graph Architecture Training Configuration
-# Uses joint protein-ligand graph with cross-edges instead of separate encoders + attention
-
-device: cuda
-seed: 42
-
-# Data
-data:
- data_dir: train_data
- split_file: train_data/splits.json
- max_train_samples: null
- max_val_samples: null
- num_workers: 8
- loading_mode: lazy
- # Timestep sampling for flow matching
- timestep_sampling:
- method: uniform # uniform for pose refinement (t=0 is structure, not noise)
- mu: 0.8 # unused for uniform
- sigma: 1.7 # unused for uniform
- num_timesteps_per_sample: 1 # K=1 for debugging (was 4)
-
-# Model - Joint Graph Architecture
-model:
- architecture: joint # "separate" (default) or "joint" (new)
-
- # Protein input dims (unchanged)
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
-
- # Ligand input dims (unchanged)
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
-
- # Joint network hidden dims (increased from 64/16 for better capacity)
- hidden_scalar_dim: 128
- hidden_vector_dim: 32
- hidden_edge_dim: 128
- joint_num_layers: 6 # Unified: message passing + velocity prediction in one network
-
- # Cross-edge parameters (protein-ligand)
- cross_edge_distance_cutoff: 6.0 # Angstroms
- cross_edge_max_neighbors: 16 # Max neighbors per node (KNN cap within cutoff)
- cross_edge_num_rbf: 32 # RBF features for cross-edge distances
-
- # Intra-edge parameters (dynamic within protein/ligand, supplements pre-computed edges)
- intra_edge_distance_cutoff: 6.0 # Angstroms
- intra_edge_max_neighbors: 16 # Max neighbors per node
-
- # Time conditioning hidden dim
- hidden_dim: 256
-
- # ESM embedding integration
- esm_proj_dim: 128 # ESM projection dim (gated concatenation to protein features)
-
- # Self-conditioning: condition on predicted x1 (crystal) from first pass (50% of training)
- # Standard in protein diffusion (RFdiffusion, FrameDiff, FrameFlow)
- self_conditioning: true
-
- dropout: 0.1
-
-# Training
-training:
- num_epochs: 1000
- batch_size: 128 # Per-GPU batch size (8x GPU = 1024 effective)
- val_batch_size: 128
- gradient_clip: 100.0
- gradient_accumulation_steps: 1 # batch already large enough (256 × 8 = 2048)
- distance_geometry_weight: 0.1
- ema:
- enabled: true
- decay: 0.999
- compile: true # torch_cluster functions are excluded via @torch._dynamo.disable
- compile_options:
- mode: default
- dynamic: true
-
- protein_ligand_clash:
- enabled: false
- ca_threshold: 3.0
- sc_threshold: 2.5
- margin: 1.0
- weight: 0.2
-
- optimizer:
- type: muon
- muon:
- lr: 0.005 # Reduced from 0.02 to stabilize training
- momentum: 0.95
- weight_decay: 0.0
- min_lr: 0.0005 # Cosine decay target (10% of peak)
- adamw:
- lr: 0.0003
- weight_decay: 0.01
- betas: [0.9, 0.999]
- eps: 1.0e-08
- min_lr: 0.00003 # Cosine decay target (10% of peak)
-
- # Unified LR schedule: linear warmup + plateau + cosine decay
- schedule:
- warmup_fraction: 0.05 # 5% of total epochs (0 -> max LR)
- plateau_fraction: 0.80 # 80% of total epochs (keep max LR)
- decay_fraction: 0.15 # 15% of total epochs (cosine max -> min LR)
- warmup_epochs: 5 # Fallback for legacy behavior if fractions are not used
-
- validation:
- frequency: 20
- early_stopping_patience: 30
-
-# Sampling (Inference/Validation)
-sampling:
- num_steps: 20
- method: euler
- schedule: quadratic
-
-# Experiment management
-experiment:
- base_dir: save
-
-# Checkpoints
-checkpoint:
- save_freq: 10
- save_latest: true
- keep_last_n: 5
-
-# Visualization
-visualization:
- enabled: true
- save_animation: true
- animation_fps: 10
- num_samples: 1
-
-# WandB logging
-wandb:
- enabled: true
- project: "protein-ligand-flowfix"
- entity: null
- name: null
- tags: ["flow-matching", "protein-ligand", "se3-equivariant", "joint-graph"]
- log_gradients: true
- log_model_weights: true
- log_animations: true
diff --git a/configs/train_joint_test.yaml b/configs/train_joint_test.yaml
deleted file mode 100644
index 8155ba5..0000000
--- a/configs/train_joint_test.yaml
+++ /dev/null
@@ -1,139 +0,0 @@
-# FlowFix Joint Graph - Test Configuration (Single A5000 GPU)
-# Reduced batch size for 24GB VRAM
-
-device: cuda
-seed: 42
-
-# Data
-data:
- data_dir: train_data
- split_file: train_data/splits.json
- max_train_samples: null
- max_val_samples: null
- num_workers: 4
- loading_mode: lazy
- timestep_sampling:
- method: uniform
- mu: 0.8
- sigma: 1.7
-
-# Model - Joint Graph Architecture
-model:
- architecture: joint
-
- # Protein input dims
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
-
- # Ligand input dims
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
-
- # Joint network hidden dims
- hidden_scalar_dim: 128
- hidden_vector_dim: 32
- hidden_edge_dim: 128
- joint_num_layers: 6
-
- # Cross-edge parameters
- cross_edge_distance_cutoff: 6.0
- cross_edge_max_neighbors: 16
- cross_edge_num_rbf: 32
-
- # Intra-edge parameters
- intra_edge_distance_cutoff: 6.0
- intra_edge_max_neighbors: 16
-
- # Time conditioning hidden dim
- hidden_dim: 256
-
- # ESM embedding integration
- esm_proj_dim: 128
-
- # Self-conditioning
- self_conditioning: true
-
- dropout: 0.1
-
-# Training - Reduced for single A5000 (24GB)
-training:
- num_epochs: 100 # Reduced for quick test
- batch_size: 32 # Reduced from 1024 for A5000 (24GB)
- val_batch_size: 32
- gradient_clip: 100.0
- gradient_accumulation_steps: 1
- distance_geometry_weight: 0.1
- ema:
- enabled: true
- decay: 0.999
- compile: false # Disabled for debugging
- compile_options:
- mode: default
- dynamic: true
-
- protein_ligand_clash:
- enabled: false
- ca_threshold: 3.0
- sc_threshold: 2.5
- margin: 1.0
- weight: 0.2
-
- optimizer:
- type: muon
- muon:
- lr: 0.02
- momentum: 0.95
- weight_decay: 0.0
- min_lr: 0.002
- adamw:
- lr: 0.0003
- weight_decay: 0.01
- betas: [0.9, 0.999]
- eps: 1.0e-08
- min_lr: 0.00003
-
- schedule:
- warmup_fraction: 0.05
- plateau_fraction: 0.80
- decay_fraction: 0.15
- warmup_epochs: 5
-
- validation:
- frequency: 10 # More frequent validation for test
- early_stopping_patience: 30
-
-# Sampling (Inference/Validation)
-sampling:
- num_steps: 10 # Reduced for faster validation
- method: euler
- schedule: quadratic
-
-# Experiment management
-experiment:
- base_dir: save
-
-# Checkpoints
-checkpoint:
- save_freq: 10
- save_latest: true
- keep_last_n: 3
-
-# Visualization
-visualization:
- enabled: false # Disabled for quick test
- save_animation: false
- animation_fps: 10
- num_samples: 1
-
-# WandB logging
-wandb:
- enabled: true
- project: "protein-ligand-flowfix"
- entity: null
- name: null
- tags: ["flow-matching", "protein-ligand", "se3-equivariant", "joint-graph", "test"]
- log_gradients: true
- log_model_weights: true
- log_animations: false
diff --git a/configs/train_overfit_32.yaml b/configs/train_overfit_32.yaml
deleted file mode 100644
index 3e14c04..0000000
--- a/configs/train_overfit_32.yaml
+++ /dev/null
@@ -1,124 +0,0 @@
-# Overfit Test: 32 training samples, val=subset of train
-# Purpose: More stable BatchNorm stats while still testing overfitting
-
-device: cuda
-seed: 42
-
-# Data - 32 train, 4 val (subset of train)
-data:
- data_dir: train_data
- split_file: train_data/splits_overfit_32.json
- max_train_samples: null
- max_val_samples: 4
- fix_pose: true
- fix_pose_high_rmsd: false
- position_noise_scale: 0.1
- num_workers: 4
- loading_mode: lazy
- timestep_sampling:
- method: uniform
- mu: 0.8
- sigma: 1.7
- num_timesteps_per_sample: 1
-
-# Model
-model:
- architecture: joint
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
- hidden_scalar_dim: 128
- hidden_vector_dim: 32
- hidden_edge_dim: 128
- joint_num_layers: 6
- cross_edge_distance_cutoff: 6.0
- cross_edge_max_neighbors: 16
- cross_edge_num_rbf: 32
- intra_edge_distance_cutoff: 6.0
- intra_edge_max_neighbors: 16
- hidden_dim: 256
- esm_proj_dim: 128
- self_conditioning: true
- dropout: 0.0
-
-# Training
-training:
- num_epochs: 500
- batch_size: 16 # Larger batch for BatchNorm stability
- val_batch_size: 4
- gradient_clip: 100.0
- gradient_accumulation_steps: 1
- distance_geometry_weight: 0.1
- ema:
- enabled: false
- decay: 0.999
- compile: false
-
- protein_ligand_clash:
- enabled: false
- ca_threshold: 3.0
- sc_threshold: 2.5
- margin: 1.0
- weight: 0.2
-
- optimizer:
- type: muon
- muon:
- lr: 0.02
- momentum: 0.95
- weight_decay: 0.0
- min_lr: 0.002
- adamw:
- lr: 0.0003
- weight_decay: 0.0
- betas: [0.9, 0.999]
- eps: 1.0e-08
- min_lr: 0.00003
-
- schedule:
- warmup_fraction: 0.05
- plateau_fraction: 0.80
- decay_fraction: 0.15
- warmup_epochs: 5
-
- validation:
- frequency: 50
- early_stopping_patience: 1000
-
-# Sampling
-sampling:
- num_steps: 20
- method: euler
- schedule: uniform
- t_end: 1.0
-
-# Experiment
-experiment:
- base_dir: save
-
-# Checkpoints
-checkpoint:
- save_freq: 50
- save_latest: true
- keep_last_n: 3
-
-# Visualization
-visualization:
- enabled: true
- save_animation: true
- animation_fps: 10
- num_samples: 1
-
-# WandB
-wandb:
- enabled: true
- project: "protein-ligand-flowfix"
- entity: null
- name: "overfit-test-32-fixed"
- tags: ["overfit-test", "batchnorm-fix", "fixed-edges"]
- log_gradients: false
- log_model_weights: false
- log_animations: true
diff --git a/configs/train_overfit_high_rmsd.yaml b/configs/train_overfit_high_rmsd.yaml
deleted file mode 100644
index 5f1beec..0000000
--- a/configs/train_overfit_high_rmsd.yaml
+++ /dev/null
@@ -1,124 +0,0 @@
-# Overfit Test: 32 training samples, val=subset of train
-# Purpose: More stable BatchNorm stats while still testing overfitting
-
-device: cuda
-seed: 42
-
-# Data - 32 train, 4 val (subset of train)
-data:
- data_dir: train_data
- split_file: train_data/splits_overfit_32.json
- max_train_samples: null
- max_val_samples: 4
- fix_pose: true
- fix_pose_high_rmsd: true
- position_noise_scale: 0.1
- num_workers: 4
- loading_mode: lazy
- timestep_sampling:
- method: uniform
- mu: 0.8
- sigma: 1.7
- num_timesteps_per_sample: 1
-
-# Model
-model:
- architecture: joint
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
- hidden_scalar_dim: 128
- hidden_vector_dim: 32
- hidden_edge_dim: 128
- joint_num_layers: 6
- cross_edge_distance_cutoff: 6.0
- cross_edge_max_neighbors: 16
- cross_edge_num_rbf: 32
- intra_edge_distance_cutoff: 6.0
- intra_edge_max_neighbors: 16
- hidden_dim: 256
- esm_proj_dim: 128
- self_conditioning: true
- dropout: 0.0
-
-# Training
-training:
- num_epochs: 500
- batch_size: 16 # Larger batch for BatchNorm stability
- val_batch_size: 4
- gradient_clip: 100.0
- gradient_accumulation_steps: 1
- distance_geometry_weight: 0.1
- ema:
- enabled: false
- decay: 0.999
- compile: false
-
- protein_ligand_clash:
- enabled: false
- ca_threshold: 3.0
- sc_threshold: 2.5
- margin: 1.0
- weight: 0.2
-
- optimizer:
- type: muon
- muon:
- lr: 0.02
- momentum: 0.95
- weight_decay: 0.0
- min_lr: 0.002
- adamw:
- lr: 0.0003
- weight_decay: 0.0
- betas: [0.9, 0.999]
- eps: 1.0e-08
- min_lr: 0.00003
-
- schedule:
- warmup_fraction: 0.05
- plateau_fraction: 0.80
- decay_fraction: 0.15
- warmup_epochs: 5
-
- validation:
- frequency: 50
- early_stopping_patience: 1000
-
-# Sampling
-sampling:
- num_steps: 20
- method: euler
- schedule: uniform
- t_end: 1.0
-
-# Experiment
-experiment:
- base_dir: save
-
-# Checkpoints
-checkpoint:
- save_freq: 50
- save_latest: true
- keep_last_n: 3
-
-# Visualization
-visualization:
- enabled: true
- save_animation: true
- animation_fps: 10
- num_samples: 1
-
-# WandB
-wandb:
- enabled: true
- project: "protein-ligand-flowfix"
- entity: null
- name: "overfit-test-high-rmsd"
- tags: ["high-rmsd", "overfit-test", "batchnorm-fix", "fixed-edges"]
- log_gradients: false
- log_model_weights: false
- log_animations: true
diff --git a/configs/train_overfit_test.yaml b/configs/train_overfit_test.yaml
deleted file mode 100644
index cb8073d..0000000
--- a/configs/train_overfit_test.yaml
+++ /dev/null
@@ -1,124 +0,0 @@
-# Overfit Test: Small dataset (32 samples), train=val
-# Purpose: Verify model can learn/overfit to same data
-
-device: cuda
-seed: 42
-
-# Data - tiny overfit split
-data:
- data_dir: train_data
- split_file: train_data/splits_overfit_tiny.json # 8 samples, train=val
- max_train_samples: null
- max_val_samples: 8 # Only validate on 8 samples for speed
- fix_pose: true # Use pose 1 (easier, RMSD ~1A)
- fix_pose_high_rmsd: false # Disable high RMSD poses for now
- position_noise_scale: 0.1 # Add 0.1A noise to x_t for ODE robustness (smaller for stability)
- num_workers: 4
- loading_mode: lazy
- timestep_sampling:
- method: uniform
- mu: 0.8
- sigma: 1.7
- num_timesteps_per_sample: 1
-
-# Model - same as main config
-model:
- architecture: joint
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
- hidden_scalar_dim: 128
- hidden_vector_dim: 32
- hidden_edge_dim: 128
- joint_num_layers: 6
- cross_edge_distance_cutoff: 6.0
- cross_edge_max_neighbors: 16
- cross_edge_num_rbf: 32
- intra_edge_distance_cutoff: 6.0
- intra_edge_max_neighbors: 16
- hidden_dim: 256
- esm_proj_dim: 128
- self_conditioning: true # Keep enabled, helps learning
- dropout: 0.0 # No dropout for overfit test
-
-# Training - fast iteration
-training:
- num_epochs: 1000
- batch_size: 8 # All 8 samples in one batch
- val_batch_size: 8
- gradient_clip: 100.0
- gradient_accumulation_steps: 1
- distance_geometry_weight: 0.1
- ema:
- enabled: false # Disable EMA for faster iteration
- decay: 0.999
- compile: false # Disable compile for faster startup
-
- protein_ligand_clash:
- enabled: false
- ca_threshold: 3.0
- sc_threshold: 2.5
- margin: 1.0
- weight: 0.2
-
- optimizer:
- type: muon
- muon:
- lr: 0.02 # Muon can use high LR
- momentum: 0.95
- weight_decay: 0.0
- min_lr: 0.002
- adamw:
- lr: 0.0003
- weight_decay: 0.0 # No weight decay for overfit test
- betas: [0.9, 0.999]
- eps: 1.0e-08
- min_lr: 0.00003
-
- schedule:
- warmup_fraction: 0.05
- plateau_fraction: 0.80
- decay_fraction: 0.15
- warmup_epochs: 5
-
- validation:
- frequency: 100 # Validate every 100 epochs for closer monitoring
- early_stopping_patience: 1000 # Don't early stop
-
-# Sampling
-sampling:
- num_steps: 20 # More steps for smoother integration
- method: euler
- schedule: uniform # Linear spacing
- t_end: 1.0 # Full integration t=0~1
-
-# Experiment
-experiment:
- base_dir: save
-
-# Checkpoints
-checkpoint:
- save_freq: 50
- save_latest: true
- keep_last_n: 3
-
-# Visualization
-visualization:
- enabled: true
- save_animation: true
- animation_fps: 10
- num_samples: 1
-
-# WandB
-wandb:
- enabled: true
- project: "protein-ligand-flowfix"
- entity: null
- name: "overfit-test-8"
- tags: ["overfit-test", "tiny"]
- log_gradients: false
- log_model_weights: false
- log_animations: true
diff --git a/configs/train_rectified_flow.yaml b/configs/train_rectified_flow.yaml
deleted file mode 100644
index f9a6693..0000000
--- a/configs/train_rectified_flow.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# Rectified Flow Test: Stable Linear Interpolation
-data:
- data_dir: train_data
- split_file: train_data/splits_overfit_32.json
- max_train_samples: null
- max_val_samples: 4
- fix_pose: true
- fix_pose_high_rmsd: true
- position_noise_scale: 0.05
- num_workers: 4
- loading_mode: lazy
- timestep_sampling:
- method: uniform
- mu: 0.8
- sigma: 1.7
- num_timesteps_per_sample: 1
-
-model:
- architecture: joint
- protein_input_scalar_dim: 76
- protein_input_vector_dim: 31
- protein_input_edge_scalar_dim: 39
- protein_input_edge_vector_dim: 8
- ligand_input_scalar_dim: 122
- ligand_input_edge_scalar_dim: 44
- hidden_scalar_dim: 128
- hidden_vector_dim: 32
- hidden_edge_dim: 128
- joint_num_layers: 6
- cross_edge_distance_cutoff: 6.0
- cross_edge_max_neighbors: 16
- cross_edge_num_rbf: 32
- intra_edge_distance_cutoff: 6.0
- intra_edge_max_neighbors: 16
- hidden_dim: 256
- esm_proj_dim: 128
- self_conditioning: false
- dropout: 0.0
-
-training:
- num_epochs: 300
- batch_size: 4 # 8 updates per epoch (32/4)
- val_batch_size: 4
- gradient_clip: 1.0 # Standard clip is fine for Rectified Flow
- gradient_accumulation_steps: 1
- distance_geometry_weight: 0.1
- ema:
- enabled: false
- compile: false
-
- optimizer:
- type: muon
- muon:
- lr: 0.05
- momentum: 0.95
- weight_decay: 0.0
- min_lr: 0.005
- adamw:
- lr: 0.001
- weight_decay: 0.0
- betas: [0.9, 0.999]
- eps: 1.0e-08
- min_lr: 0.0001
-
- schedule:
- warmup_fraction: 0.05
- plateau_fraction: 0.80
- decay_fraction: 0.15
- warmup_epochs: 5
-
- validation:
- frequency: 20
- early_stopping_patience: 1000
-
-sampling:
- num_steps: 20
- method: euler
- schedule: uniform
- t_end: 1.0
-
-experiment:
- base_dir: save
-
-checkpoint:
- save_freq: 20
- save_latest: true
- keep_last_n: 3
-
-visualization:
- enabled: true
- save_animation: true
- animation_fps: 10
- num_samples: 1
-
-wandb:
- enabled: true
- project: "protein-ligand-flowfix"
- name: "rectified-flow-stable"
- tags: ["overfit-test", "rectified-flow", "stable"]
- log_gradients: false
- log_model_weights: false
- log_animations: true
diff --git a/configs/train_torsion.yaml b/configs/train_torsion.yaml
new file mode 100644
index 0000000..aa95cdc
--- /dev/null
+++ b/configs/train_torsion.yaml
@@ -0,0 +1,117 @@
+# FlowFix SE(3) + Torsion Decomposition Training Configuration
+# Output: translation [3] + rotation [3] + torsion [M] instead of per-atom velocity [N, 3]
+
+device: cuda
+seed: 42
+
+# Data
+data:
+ data_dir: train_data
+ split_file: train_data/splits.json
+ max_train_samples: null
+ max_val_samples: null
+ num_workers: 8
+ loading_mode: lazy
+ timestep_sampling:
+ method: uniform
+ num_timesteps_per_sample: 1
+
+# Model
+# Uses ProteinLigandFlowMatchingTorsion (separate encoders + interaction + torsion heads)
+# Run with: python train_torsion.py --config configs/train_torsion.yaml
+model:
+
+ # Protein input dims
+ protein_input_scalar_dim: 76
+ protein_input_vector_dim: 31
+ protein_input_edge_scalar_dim: 39
+ protein_input_edge_vector_dim: 8
+
+ # Ligand input dims
+ ligand_input_scalar_dim: 122
+ ligand_input_edge_scalar_dim: 44
+
+ # Hidden dims
+ hidden_scalar_dim: 128
+ hidden_vector_dim: 32
+ hidden_edge_dim: 128
+
+ # Time conditioning
+ hidden_dim: 256
+
+ # ESM embedding
+ esm_proj_dim: 128
+
+ self_conditioning: false
+ dropout: 0.1
+
+# Training
+training:
+ num_epochs: 500
+ batch_size: 32
+ val_batch_size: 32
+ gradient_clip: 10.0
+ gradient_accumulation_steps: 1
+
+ # SE(3) + Torsion loss weights
+ torsion_loss:
+ w_trans: 1.0 # Translation MSE weight
+ w_rot: 1.0 # Rotation MSE weight
+ w_tor: 1.0 # Torsion circular MSE weight
+ w_coord: 0.5 # Coordinate reconstruction weight (end-to-end)
+
+ distance_geometry_weight: 0.0 # Disabled for torsion mode (built-in constraints)
+
+ ema:
+ enabled: true
+ decay: 0.999
+ compile: false # Disable for initial debugging
+
+ optimizer:
+ type: muon
+ muon:
+ lr: 0.005
+ momentum: 0.95
+ weight_decay: 0.0
+ min_lr: 0.0005
+ adamw:
+ lr: 0.0003
+ weight_decay: 0.01
+ betas: [0.9, 0.999]
+ eps: 1.0e-08
+ min_lr: 0.00003
+
+ schedule:
+ warmup_fraction: 0.05
+ plateau_fraction: 0.80
+ decay_fraction: 0.15
+
+ validation:
+ frequency: 20
+ early_stopping_patience: 30
+
+# Sampling
+sampling:
+ num_steps: 20
+ method: euler
+ schedule: uniform
+
+# Experiment
+experiment:
+ base_dir: save
+
+checkpoint:
+ save_freq: 10
+ save_latest: true
+ keep_last_n: 5
+
+visualization:
+ enabled: true
+ save_animation: true
+ animation_fps: 10
+ num_samples: 1
+
+wandb:
+ enabled: true
+ project: "protein-ligand-flowfix"
+ tags: ["flow-matching", "se3-torsion", "decomposition"]
diff --git a/output/cartesian_v4_baseline/README.md b/output/cartesian_v4_baseline/README.md
new file mode 100644
index 0000000..7e97b64
--- /dev/null
+++ b/output/cartesian_v4_baseline/README.md
@@ -0,0 +1,22 @@
+# Cartesian v4 Baseline (rectified-flow-full-v4)
+
+## Model
+- **Architecture**: Joint graph, 6-layer GatingEquivariantLayer
+- **Output**: Per-atom Cartesian velocity [N, 3]
+- **Epochs**: 500
+- **Checkpoint**: `save/rectified-flow-full-v4/checkpoints/latest.pt` (1.1GB)
+
+## Results (Full Validation, 200 PDBs)
+
+| Metric | Before | After | Delta |
+|--------|--------|-------|-------|
+| Mean RMSD | 3.20 A | 2.64 A | -0.56 A |
+| Median RMSD | 3.00 A | 2.22 A | -0.78 A |
+
+## Files
+
+- `train_config.yaml` - Training configuration used
+- `full_validation_results.json` - Full validation inference results (200 PDBs, all poses)
+- `latest_train_valid_5.json` - Small train/valid 5 sample inference
+- `train5_allposes_latest.json` - Train 5 all poses inference
+- Visualization plots: `reports/cartesian/assets/` (RMSD distribution, scatter, improvement)
diff --git a/inference_results/full_validation_results.json b/output/cartesian_v4_baseline/full_validation_results.json
similarity index 100%
rename from inference_results/full_validation_results.json
rename to output/cartesian_v4_baseline/full_validation_results.json
diff --git a/inference_results/latest_train_valid_5.json b/output/cartesian_v4_baseline/latest_train_valid_5.json
similarity index 100%
rename from inference_results/latest_train_valid_5.json
rename to output/cartesian_v4_baseline/latest_train_valid_5.json
diff --git a/inference_results/train5_allposes_latest.json b/output/cartesian_v4_baseline/train5_allposes_latest.json
similarity index 100%
rename from inference_results/train5_allposes_latest.json
rename to output/cartesian_v4_baseline/train5_allposes_latest.json
diff --git a/configs/train_rectified_flow_full.yaml b/output/cartesian_v4_baseline/train_config.yaml
similarity index 66%
rename from configs/train_rectified_flow_full.yaml
rename to output/cartesian_v4_baseline/train_config.yaml
index 476b606..311b6ba 100644
--- a/configs/train_rectified_flow_full.yaml
+++ b/output/cartesian_v4_baseline/train_config.yaml
@@ -1,9 +1,8 @@
-# Rectified Flow Full Training
data:
data_dir: train_data
split_file: train_data/splits.json
max_train_samples: null
- max_val_samples: null # Use full validation set
+ max_val_samples: null
fix_pose: false
fix_pose_high_rmsd: false
position_noise_scale: 0.05
@@ -14,7 +13,6 @@ data:
mu: 0.8
sigma: 1.7
num_timesteps_per_sample: 2
-
model:
architecture: joint
protein_input_scalar_dim: 76
@@ -35,64 +33,60 @@ model:
hidden_dim: 384
esm_proj_dim: 128
self_conditioning: false
- dropout: 0.1 # Add dropout for generalization
-
+ dropout: 0.1
training:
num_epochs: 500
- batch_size: 32 # Balanced for v4 (1.5x scale)
- val_batch_size: 32 # Increased for faster validation
+ batch_size: 32
+ val_batch_size: 32
gradient_clip: 1.0
gradient_accumulation_steps: 1
distance_geometry_weight: 0.1
ema:
- enabled: true # Enable EMA for stable implementation
+ enabled: true
decay: 0.999
compile: false
-
optimizer:
- type: adamw # Safe default for full run, or muon if tested
+ type: adamw
adamw:
lr: 0.0001
weight_decay: 0.0001
- betas: [0.9, 0.999]
+ betas:
+ - 0.9
+ - 0.999
eps: 1.0e-08
- min_lr: 1.0e-6
-
+ min_lr: 1.0e-06
schedule:
warmup_fraction: 0.05
- plateau_fraction: 0.80
+ plateau_fraction: 0.8
decay_fraction: 0.15
warmup_epochs: 5
-
validation:
- frequency: 20 # Validate every 20 epochs
+ frequency: 20
early_stopping_patience: 50
-
sampling:
num_steps: 20
method: euler
schedule: uniform
t_end: 1.0
-
experiment:
base_dir: save
-
checkpoint:
save_freq: 5
save_latest: true
keep_last_n: 5
-
visualization:
enabled: true
- save_animation: false # Disable animation save to save space/time on full run
+ save_animation: false
animation_fps: 10
num_samples: 4
-
wandb:
enabled: true
- project: "protein-ligand-flowfix"
- name: "rectified-flow-full-v4"
- tags: ["full-train", "rectified-flow", "stable"]
+ project: protein-ligand-flowfix
+ name: rectified-flow-full-v4
+ tags:
+ - full-train
+ - rectified-flow
+ - stable
log_gradients: false
log_model_weights: false
log_animations: true
diff --git a/reports/assets/hist_improvement.png b/reports/assets/hist_improvement.png
deleted file mode 100644
index 0d5a1da..0000000
Binary files a/reports/assets/hist_improvement.png and /dev/null differ
diff --git a/reports/assets/hist_rmsd_distribution.png b/reports/assets/hist_rmsd_distribution.png
deleted file mode 100644
index 4b04ad8..0000000
Binary files a/reports/assets/hist_rmsd_distribution.png and /dev/null differ
diff --git a/reports/assets/scatter_atoms_vs_improvement.png b/reports/assets/scatter_atoms_vs_improvement.png
deleted file mode 100644
index e1bbd6d..0000000
Binary files a/reports/assets/scatter_atoms_vs_improvement.png and /dev/null differ
diff --git a/reports/assets/scatter_init_vs_final_pdb.png b/reports/assets/scatter_init_vs_final_pdb.png
deleted file mode 100644
index 3e04255..0000000
Binary files a/reports/assets/scatter_init_vs_final_pdb.png and /dev/null differ
diff --git a/reports/assets/scatter_init_vs_final_pose.png b/reports/assets/scatter_init_vs_final_pose.png
deleted file mode 100644
index 43d3ba3..0000000
Binary files a/reports/assets/scatter_init_vs_final_pose.png and /dev/null differ
diff --git a/reports/assets/scatter_init_vs_improvement.png b/reports/assets/scatter_init_vs_improvement.png
deleted file mode 100644
index ec541fa..0000000
Binary files a/reports/assets/scatter_init_vs_improvement.png and /dev/null differ
diff --git a/reports/architecture.md b/reports/cartesian/architecture.md
similarity index 93%
rename from reports/architecture.md
rename to reports/cartesian/architecture.md
index b29eb86..5da8a40 100644
--- a/reports/architecture.md
+++ b/reports/cartesian/architecture.md
@@ -167,14 +167,14 @@ flowchart TB
Pre-trained PLM의 residue-level embedding을 protein features에 통합.
```mermaid
-flowchart LR
- ESMC["ESMC 600M
[N, 1152]"] --> P1["MLP
1152->128->128"]
- ESM3["ESM3
[N, 1536]"] --> P2["MLP
1536->128->128"]
+flowchart TB
+ ESMC["ESMC 600M [N, 1152]"] --> P1["MLP 1152->128->128"]
+ ESM3["ESM3 [N, 1536]"] --> P2["MLP 1536->128->128"]
P1 --> WS["Weighted Sum
softmax(esm_weight)"]
P2 --> WS
- WS --> GATE["esm_gate
sigmoid MLP"]
+ WS --> GATE["esm_gate (sigmoid MLP)"]
PX["protein.x [N,76]"] --> CAT["Gated Concat"]
GATE --> CAT
@@ -204,26 +204,23 @@ flowchart LR
### 3.1 Flow Matching Loss
```mermaid
-flowchart LR
+flowchart TB
subgraph sample["Sampling"]
X0["x_0 (docked)"]
X1["x_1 (crystal)"]
T["t ~ Uniform(0,1)"]
end
- INTERP["x_t = (1-t)*x0 + t*x1"]
- FWD["v_pred = model(prot, lig_t, t)"]
- VT["v_true = x1 - x0"]
- MSE["L_flow = MSE(v_pred, v_true)"]
- DG["L_dg = dist geometry
bond constraints
time-weighted"]
- TOTAL["L = L_flow + 0.1 * L_dg"]
-
- X0 --> INTERP
+ X0 --> INTERP["x_t = (1-t)*x0 + t*x1"]
X1 --> INTERP
T --> INTERP
- INTERP --> FWD --> MSE --> TOTAL
- VT --> MSE
- DG --> TOTAL
+
+ INTERP --> FWD["v_pred = model(prot, lig_t, t)"]
+ FWD --> MSE["L_flow = MSE(v_pred, v_true)"]
+ VT["v_true = x1 - x0"] --> MSE
+
+ MSE --> TOTAL["L = L_flow + 0.1 * L_dg"]
+ DG["L_dg = dist geometry
bond constraints, time-weighted"] --> TOTAL
```
### 3.2 Training Configuration
@@ -355,8 +352,8 @@ src/utils/
model_builder.py Config -> model construction
configs/
- train_rectified_flow_full.yaml # v4 config (trained model)
- train_joint.yaml # Joint architecture template
+ train.yaml # Cartesian training config
+ train_torsion.yaml # SE(3) + Torsion training config
train.py Training loop
inference.py Inference script
diff --git a/inference_results/full_validation_results/hist_improvement.png b/reports/cartesian/assets/hist_improvement.png
similarity index 100%
rename from inference_results/full_validation_results/hist_improvement.png
rename to reports/cartesian/assets/hist_improvement.png
diff --git a/inference_results/full_validation_results/hist_rmsd_distribution.png b/reports/cartesian/assets/hist_rmsd_distribution.png
similarity index 100%
rename from inference_results/full_validation_results/hist_rmsd_distribution.png
rename to reports/cartesian/assets/hist_rmsd_distribution.png
diff --git a/reports/assets/refinement_example_1d1p.gif b/reports/cartesian/assets/refinement_example_1d1p.gif
similarity index 100%
rename from reports/assets/refinement_example_1d1p.gif
rename to reports/cartesian/assets/refinement_example_1d1p.gif
diff --git a/inference_results/full_validation_results/scatter_atoms_vs_improvement.png b/reports/cartesian/assets/scatter_atoms_vs_improvement.png
similarity index 100%
rename from inference_results/full_validation_results/scatter_atoms_vs_improvement.png
rename to reports/cartesian/assets/scatter_atoms_vs_improvement.png
diff --git a/inference_results/full_validation_results/scatter_init_vs_final_pdb.png b/reports/cartesian/assets/scatter_init_vs_final_pdb.png
similarity index 100%
rename from inference_results/full_validation_results/scatter_init_vs_final_pdb.png
rename to reports/cartesian/assets/scatter_init_vs_final_pdb.png
diff --git a/inference_results/full_validation_results/scatter_init_vs_final_pose.png b/reports/cartesian/assets/scatter_init_vs_final_pose.png
similarity index 100%
rename from inference_results/full_validation_results/scatter_init_vs_final_pose.png
rename to reports/cartesian/assets/scatter_init_vs_final_pose.png
diff --git a/inference_results/full_validation_results/scatter_init_vs_improvement.png b/reports/cartesian/assets/scatter_init_vs_improvement.png
similarity index 100%
rename from inference_results/full_validation_results/scatter_init_vs_improvement.png
rename to reports/cartesian/assets/scatter_init_vs_improvement.png
diff --git a/reports/cartesian/results.md b/reports/cartesian/results.md
new file mode 100644
index 0000000..94c818d
--- /dev/null
+++ b/reports/cartesian/results.md
@@ -0,0 +1,89 @@
+# Cartesian v4 Baseline Results
+
+> **Model**: `rectified-flow-full-v4` (joint graph, 8-layer GatingEquivariantLayer, ~13M params)
+>
+> **Output**: Per-atom Cartesian velocity [N, 3]
+>
+> **Evaluation**: 200 PDBs, 11,543 poses, 20-step Euler ODE, EMA applied
+
+---
+
+## Summary Metrics
+
+| Metric | Before Refinement | After Refinement | Change |
+|--------|-------------------|------------------|--------|
+| Mean RMSD | 3.20 A | 2.64 A | -0.56 A |
+| Median RMSD | 3.00 A | 2.22 A | -0.78 A |
+| Success rate (<2A) | 30.4% | 44.6% | +14.2%p |
+| Success rate (<1A) | 8.7% | 13.5% | +4.8%p |
+| Success rate (<0.5A) | 0.7% | 1.3% | +0.6%p |
+| Improved poses | - | 75.2% | - |
+
+---
+
+## Visualizations
+
+### RMSD Distribution: Before vs After
+
+![RMSD distribution: before vs after refinement](assets/hist_rmsd_distribution.png)
+
+Refinement 후 분포가 전체적으로 왼쪽(낮은 RMSD)으로 이동. Mean 3.20A -> 2.64A.
+
+### Per-Pose: Initial vs Final RMSD
+
+![Per-pose initial vs final RMSD scatter](assets/scatter_init_vs_final_pose.png)
+
+대각선 아래 = 개선된 pose. **75.2%의 pose가 개선됨.**
+
+### Per-PDB: Average Initial vs Final RMSD
+
+![Per-PDB average initial vs final RMSD scatter](assets/scatter_init_vs_final_pdb.png)
+
+PDB 단위로 평균하면 **200개 중 178개 (89.0%)가 개선됨.** 대부분의 target에서 일관된 개선.
+
+### RMSD Improvement Distribution
+
+![RMSD improvement distribution](assets/hist_improvement.png)
+
+Mean improvement: 0.56A, Median: 0.25A. 양의 방향(개선)으로 skewed.
+
+### Initial RMSD vs Improvement
+
+![Initial RMSD vs improvement scatter](assets/scatter_init_vs_improvement.png)
+
+Initial RMSD가 클수록 improvement 폭도 큼. 단, 매우 큰 perturbation (>8A)에서는 효과 감소.
+
+### Ligand Size vs Improvement
+
+![Ligand atom count vs improvement scatter](assets/scatter_atoms_vs_improvement.png)
+
+원자 수가 적은 ligand에서 개선 폭이 크고 분산도 큼. 큰 ligand는 상대적으로 안정적이나 개선 폭이 작음.
+
+### Refinement Trajectory Example (PDB: 1d1p)
+
+![Refinement trajectory example for PDB 1d1p](assets/refinement_example_1d1p.gif)
+
+4개 시점에서의 refinement 결과. Green = crystal, Red = current, Purple circle = initial docked pose.
+
+---
+
+## Training Configuration
+
+| Parameter | Value |
+|-----------|-------|
+| Architecture | Joint graph (8x GatingEquivariantLayer) |
+| Hidden irreps | `192x0e + 48x1o + 48x1e` (480d) |
+| Edge cutoff (PL cross) | 6.0 A, max 16 neighbors |
+| Optimizer | Muon (lr=0.005) + AdamW (lr=3e-4) |
+| Schedule | Linear warmup (5%) + Plateau (80%) + Cosine decay (15%) |
+| Loss | Velocity MSE + Distance geometry loss (weight=0.1) |
+| EMA | decay=0.999 |
+| Batch size | 32 |
+| Epochs | 500 |
+| Dropout | 0.1 |
+
+## Data
+
+- **Checkpoint**: `save/rectified-flow-full-v4/checkpoints/latest.pt`
+- **Config**: `output/cartesian_v4_baseline/train_config.yaml`
+- **Inference results**: `output/cartesian_v4_baseline/`
diff --git a/reports/progress.md b/reports/progress.md
index 83e2061..039d8c0 100644
--- a/reports/progress.md
+++ b/reports/progress.md
@@ -6,168 +6,74 @@
---
-## Overview
+## Approach Comparison
-FlowFix는 docking 결과로 얻어진 protein-ligand binding pose를 crystal structure에 가깝게 refinement하는 모델입니다.
-SE(3)-equivariant message passing network 위에서 flow matching으로 velocity field를 학습하여, perturbed pose -> crystal pose로의 ODE trajectory를 생성합니다.
-
-### Key Design Choices
-
-| Component | Choice | Rationale |
-|-----------|--------|-----------|
-| Representation | Joint protein-ligand graph | Cross-edge로 protein context 직접 전달 |
-| Equivariance | cuEquivariance tensor product | SE(3) symmetry 보존, GPU-accelerated |
-| Interaction | Direct message passing (no attention) | 단순하고 효율적인 protein-ligand interaction |
-| Generative model | Flow matching (rectified flow) | Stable training, fast sampling |
-| Protein embedding | ESMC 600M + ESM3 (weighted, gated) | Pre-trained sequence representation |
-| Optimizer | Muon + AdamW hybrid | 2D weight에 Muon, 나머지 AdamW |
+| | Cartesian (v4) | SE(3) + Torsion |
+|--|----------------|-----------------|
+| **Output** | Per-atom velocity [N, 3] | Trans [3] + Rot [3] + Torsion [M] |
+| **Dimension** | 3N (~132D) | 3+3+M (~14D) |
+| **Status** | Trained (500 epochs) | Implemented, not yet trained |
+| **Mean RMSD** | 3.20A -> 2.64A | - |
+| **Success <2A** | 30.4% -> 44.6% | - |
+| **Improved poses** | 75.2% | - |
---
-## Current Results (v4 - Baseline)
+## Cartesian (v4 Baseline)
-**Model**: `rectified-flow-full-v4` (joint graph, 6-layer GatingEquivariantLayer)
-**Evaluation**: 200 PDBs, 11,543 poses, 20-step Euler ODE, EMA applied
+- **Architecture**: Joint graph, 8x GatingEquivariantLayer, ~13M params
+- **Results**: [cartesian/results.md](cartesian/results.md)
+- **Architecture detail**: [cartesian/architecture.md](cartesian/architecture.md)
+- **Inference data**: `output/cartesian_v4_baseline/`
-### Summary Metrics
+### Key Results
-| Metric | Before Refinement | After Refinement | Change |
-|--------|-------------------|------------------|--------|
+| Metric | Before | After | Change |
+|--------|--------|-------|--------|
| Mean RMSD | 3.20 A | 2.64 A | -0.56 A |
| Median RMSD | 3.00 A | 2.22 A | -0.78 A |
| Success rate (<2A) | 30.4% | 44.6% | +14.2%p |
-| Success rate (<1A) | 8.7% | 13.5% | +4.8%p |
-| Success rate (<0.5A) | 0.7% | 1.3% | +0.6%p |
-| Improved poses | - | 75.2% | - |
-
-### RMSD Distribution: Before vs After
-
-
-
-Refinement 후 분포가 전체적으로 왼쪽(낮은 RMSD)으로 이동. Mean 3.20A -> 2.64A.
-
-### Per-Pose: Initial vs Final RMSD
-
-
-
-대각선 아래 = 개선된 pose. **75.2%의 pose가 개선됨.**
-
-### Per-PDB: Average Initial vs Final RMSD
-
-
-
-PDB 단위로 평균하면 **200개 중 178개 (89.0%)가 개선됨.** 대부분의 target에서 일관된 개선.
-
-### RMSD Improvement Distribution
-
-
-Mean improvement: 0.56A, Median: 0.25A. 양의 방향(개선)으로 skewed.
-
-### Initial RMSD vs Improvement
-
-
-
-Initial RMSD가 클수록 improvement 폭도 큼. 단, 매우 큰 perturbation (>8A)에서는 효과 감소.
-
-### Ligand Size vs Improvement
-
-
-
-원자 수가 적은 ligand에서 개선 폭이 크고 분산도 큼. 큰 ligand는 상대적으로 안정적이나 개선 폭이 작음.
-
-### Refinement Trajectory Example (PDB: 1d1p)
+---
-
+## SE(3) + Torsion
-4개 시점에서의 refinement 결과. Green = crystal, Red = current, Purple circle = initial docked pose.
-Velocity field (orange arrows)를 따라 crystal structure 방향으로 이동.
+- **Architecture**: Shared backbone + decomposed output heads (trans/rot/torsion)
+- **Results**: [torsion/results.md](torsion/results.md)
+- **Architecture detail**: [torsion/architecture.md](torsion/architecture.md)
----
+### Expected Benefits
-## Architecture
-
-> 상세 아키텍처 문서: [architecture.md](architecture.md)
-
-```mermaid
-flowchart LR
- subgraph prep["Preprocessing"]
- P["Protein"] --> ESM["ESM Integration"]
- T["time t"] --> TIME["Time Embedding"]
- end
-
- subgraph joint["Joint Graph"]
- ESM --> PPROJ["Protein Proj"]
- L["Ligand"] --> LPROJ["Ligand Proj"]
- PPROJ --> MERGE["Merge + 4 Edge Types"]
- LPROJ --> MERGE
- end
-
- subgraph mp["Message Passing"]
- MERGE --> LAYERS["8x GatingEquivLayer
+ Time AdaLN"]
- TIME --> LAYERS
- end
-
- subgraph out["Output"]
- LAYERS --> EXT["Extract Ligand"]
- EXT --> VEL["EquivariantMLP
-> v [N_l, 3]"]
- end
-
- style prep fill:#e8f5e9,stroke:#2E7D32
- style joint fill:#fff3e0,stroke:#E65100
- style mp fill:#fce4ec,stroke:#C62828
- style out fill:#f3e5f5,stroke:#6A1B9A
-```
-
-**Joint graph architecture:**
-- Protein + ligand를 하나의 그래프로 합쳐서 **direct message passing** (cross-attention 없음)
-- 4 edge types: PP (pre-computed), LL bonds (pre-computed), LL intra (dynamic), PL cross (dynamic, 6.0A cutoff)
-- 8x GatingEquivariantLayer with time conditioning via AdaLN
-- Velocity output: ligand slice 추출 후 EquivariantMLP -> 3D velocity
-
-### Training Setup
-
-| Parameter | Value |
-|-----------|-------|
-| Architecture | Joint graph (8x GatingEquivariantLayer) |
-| Hidden irreps | `192x0e + 48x1o + 48x1e` (480d) |
-| Edge cutoff (PL cross) | 6.0 A, max 16 neighbors |
-| Optimizer | Muon (lr=0.005) + AdamW (lr=3e-4) |
-| Schedule | Linear warmup (5%) + Plateau (80%) + Cosine decay (15%) |
-| Loss | Velocity MSE + Distance geometry loss (weight=0.1) |
-| EMA | decay=0.999 |
-| Batch size | 32 |
-| Epochs | 500 |
-| Dropout | 0.1 |
+- 3-10x dimension reduction (물리적 자유도만 학습)
+- Bond length/angle 자동 보존
+- Interpretable decomposition (translation vs rotation vs torsion)
---
## TODO / Next Steps
-- [ ] Success rate <2A 목표: 60%+ (현재 44.6%)
-- [ ] Self-conditioning 효과 ablation
-- [ ] Torsion space decomposition 적용 (utilities exist in ligand_feat.py, not yet integrated)
-- [ ] Multi-step refinement (iterative application)
-- [ ] Larger dataset / cross-dataset generalization
-- [ ] Inference speed optimization (fewer ODE steps)
+- [ ] Torsion 모델 학습 및 Cartesian과 비교
+- [ ] Success rate <2A 목표: 60%+
+- [ ] Self-conditioning ablation
+- [ ] Multi-step refinement
+- [ ] Inference speed optimization
---
## Changelog
+### 2026-03-06 - SE(3) + Torsion Implementation
+- Translation [3] + Rotation [3] + Torsion [M] decomposition 구현
+- 별도 model/dataset/loss/sampling/trainer 파일 분리
+- Codebase cleanup (dead files, legacy configs 정리)
+
### 2026-02-18 - v4 Baseline Results
- Full validation on 200 PDBs (11,543 poses)
- 20-step Euler ODE with EMA model
- Mean RMSD: 3.20A -> 2.64A, Success rate <2A: 30.4% -> 44.6%
-### 2026-02 - Joint Graph Architecture (v4, current)
-- Joint protein-ligand graph with 4 edge types (PP, LL, LL intra, PL cross)
-- 8x GatingEquivariantLayer with time AdaLN conditioning
-- cuEquivariance tensor product for SE(3) equivariance
+### 2026-02 - Joint Graph Architecture (v4)
+- Joint protein-ligand graph with 4 edge types
+- 8x GatingEquivariantLayer with time AdaLN
+- cuEquivariance tensor product
- Muon + AdamW hybrid optimizer
-- EMA (decay=0.999) for inference
-
-### 2024-11 - SE(3) + Torsion Decomposition (not used in training)
-- Translation [3D] + Rotation [3D] + Torsion [M] decomposition utilities implemented
-- Not integrated into model/training - current model uses Cartesian velocity
-- Chain-wise ESM embedding support added
diff --git a/reports/torsion/architecture.md b/reports/torsion/architecture.md
new file mode 100644
index 0000000..aaeb32d
--- /dev/null
+++ b/reports/torsion/architecture.md
@@ -0,0 +1,115 @@
+# SE(3) + Torsion Decomposition Architecture
+
+> **Model**: `ProteinLigandFlowMatchingTorsion`
+>
+> **Output**: Translation [3] + Rotation [3] + Torsion [M] (instead of per-atom velocity [N, 3])
+
+---
+
+## 1. Overview
+
+Molecular pose를 SE(3) + Torsion 공간으로 분해하여, Cartesian 3N차원 대신 (3 + 3 + M)차원에서 velocity field를 학습합니다.
+
+- **Translation [3D]**: 분자 중심의 이동
+- **Rotation [3D, SO(3)]**: 분자 전체 회전 (axis-angle)
+- **Torsion [M]**: M개 rotatable bond의 회전각
+
+### Dimension Reduction
+
+일반적인 drug molecule (N=44 atoms, M=8 rotatable bonds):
+- Cartesian: 44 x 3 = **132D**
+- Torsion: 3 + 3 + 8 = **14D** (9.4x reduction)
+
+---
+
+## 2. Architecture
+
+### 2.1 Backbone (Shared with Cartesian)
+
+Encoder, interaction network, velocity blocks는 base `ProteinLigandFlowMatching`과 동일.
+
+```
+ProteinLigandFlowMatchingTorsion (inherits ProteinLigandFlowMatching)
+├── [shared] protein_encoder, ligand_encoder
+├── [shared] cross_attention, velocity_blocks
+├── [new] translation_head: EquivariantMLP -> 1x1o [3D]
+├── [new] rotation_head: EquivariantMLP -> 1x1o [3D]
+└── [new] torsion_head: MLP(2*scalar_dim -> 1) per bond
+```
+
+### 2.2 Output Heads
+
+**Translation & Rotation**: Global pooling (scatter_mean) -> EquivariantMLP -> 1x1o
+
+```
+h_scalar[mol_i] -> scatter_mean -> EquivariantMLP -> translation [3]
+h_scalar[mol_i] -> scatter_mean -> EquivariantMLP -> rotation [3]
+```
+
+**Torsion**: Edge-level prediction (DiffDock approach)
+
+```
+For each rotatable bond (src, dst):
+ cat(h_scalar[src], h_scalar[dst]) -> MLP -> scalar angle
+```
+
+### 2.3 Application Order
+
+Torsion -> Translation -> Rotation (DiffDock convention)
+
+```
+1. Apply torsion angles (Rodrigues rotation per bond)
+2. Translate center of mass
+3. Rotate around center of mass (axis-angle via Rodrigues)
+```
+
+---
+
+## 3. Loss Function
+
+```
+L = w_trans * L_trans + w_rot * L_rot + w_tor * L_tor + w_coord * L_coord
+
+L_trans = MSE(pred_trans, target_trans)
+L_rot = MSE(pred_rot, target_rot)
+L_tor = circular_MSE(pred_tor, target_tor) # atan2(sin, cos) wrapping
+L_coord = MSE(reconstruct(x0, pred), x1) # end-to-end reconstruction
+```
+
+Default weights: `w_trans=1.0, w_rot=1.0, w_tor=1.0, w_coord=0.5`
+
+---
+
+## 4. Training Configuration
+
+| Parameter | Value |
+|-----------|-------|
+| Architecture | Torsion (shared backbone + decomposed heads) |
+| Hidden scalar/vector | 128 / 32 |
+| Optimizer | Muon (lr=0.005) + AdamW (lr=3e-4) |
+| Schedule | Warmup 5% + Plateau 80% + Cosine 15% |
+| Loss | Trans MSE + Rot MSE + Circular Torsion + Coord Recon |
+| EMA | decay=0.999 |
+| Batch size | 32 |
+| Epochs | 500 |
+| ODE sampling | 20-step Euler |
+
+---
+
+## 5. File Map
+
+```
+src/models/flowmatching_torsion.py # ProteinLigandFlowMatchingTorsion
+src/data/dataset_torsion.py # FlowFixTorsionDataset, collate_torsion_batch
+src/utils/losses_torsion.py # compute_se3_torsion_loss, reconstruct_coords
+src/utils/sampling_torsion.py # sample_trajectory_torsion
+train_torsion.py # FlowFixTorsionTrainer
+configs/train_torsion.yaml # Training config
+```
+
+---
+
+## 6. References
+
+- DiffDock (Corso et al., 2023): Product space diffusion on R3 x SO(3) x T^M
+- Torsional Diffusion (Jing et al., 2022): Diffusion on torsion angles
diff --git a/reports/torsion/results.md b/reports/torsion/results.md
new file mode 100644
index 0000000..33f050e
--- /dev/null
+++ b/reports/torsion/results.md
@@ -0,0 +1,33 @@
+# SE(3) + Torsion Results
+
+> **Status**: Not yet trained
+>
+> **Model**: `ProteinLigandFlowMatchingTorsion`
+>
+> **Output**: Translation [3] + Rotation [3] + Torsion [M]
+
+---
+
+## Summary Metrics
+
+| Metric | Before Refinement | After Refinement | Change |
+|--------|-------------------|------------------|--------|
+| Mean RMSD | - | - | - |
+| Median RMSD | - | - | - |
+| Success rate (<2A) | - | - | - |
+| Success rate (<1A) | - | - | - |
+| Improved poses | - | - | - |
+
+> 학습 완료 후 업데이트 예정
+
+---
+
+## Visualizations
+
+> 학습 완료 후 `assets/`에 추가 예정
+
+---
+
+## Training Log
+
+> WandB 링크 및 학습 커브 추가 예정
diff --git a/scripts/README.md b/scripts/README.md
index 8a1c2fb..45b2876 100755
--- a/scripts/README.md
+++ b/scripts/README.md
@@ -15,6 +15,8 @@ scripts/
- `visualize_loss.py`: 학습 로그 기반 loss 시각화
- `visualize_trajectory.py`: 샘플링 trajectory 시각화
+- `infer_full_validation.py`: 전체 validation set 추론
+- `infer_small_train_valid.py`: 소규모 train/valid 추론
## `scripts/data/`
@@ -24,14 +26,11 @@ scripts/
- `generate_test_data.py`: 테스트 데이터 생성
- `inspect_data.py`: 데이터 점검
- `verify_test_data.py`: 테스트 데이터 검증
-- `run_preprocess_pdbbind.sh`: 전처리 실행용 쉘 래퍼
## `scripts/slurm/`
환경(경로, 파티션, GPU 수)은 클러스터마다 다르므로, 본인 환경에 맞게 `PYTHON`, `PROJECT_DIR`, `#SBATCH ...`를 조정해서 사용하세요.
-- `run_train_joint_test.sh`: 단일 GPU 빠른 테스트
-- `run_train_joint.sh`: 멀티 GPU(DDP) 학습
+- `run_train_full.sh`: 멀티 GPU(DDP) 학습
- `run_inference.sh`: 체크포인트 추론/평가
-- `run_visualize_trajectory.sh`: trajectory 시각화
-
+- `run_full_val_inference.sh`: 전체 validation 추론
diff --git a/scripts/analysis/infer_full_validation.py b/scripts/analysis/infer_full_validation.py
index 6e419c2..0311692 100755
--- a/scripts/analysis/infer_full_validation.py
+++ b/scripts/analysis/infer_full_validation.py
@@ -310,7 +310,7 @@ def save_partial(signum, frame):
def main():
p = argparse.ArgumentParser(description="Full validation inference")
- p.add_argument("--config", default="configs/train_joint.yaml")
+ p.add_argument("--config", default="configs/train.yaml")
p.add_argument("--checkpoint", default="save/rectified-flow-full-v4/checkpoints/latest.pt")
p.add_argument("--output", default="inference_results/full_validation_results.json")
p.add_argument("--no_ema", action="store_true")
diff --git a/scripts/debug_ckpt.py b/scripts/debug_ckpt.py
deleted file mode 100755
index 97b37b4..0000000
--- a/scripts/debug_ckpt.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import os
-import torch
-import yaml
-import argparse
-from pathlib import Path
-from torch_geometric.loader import DataLoader
-from src.data.dataset import FlowFixDataset, collate_flowfix_batch
-from src.models.flowmatching import ProteinLigandFlowMatchingJoint
-from src.utils.relaxation import RelaxationEngine
-
-def test_checkpoint(checkpoint_path, config_path, split="val", pdb_id=None, do_relax=False):
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- print(f"\n" + "="*60)
- print(f"Testing Checkpoint: {checkpoint_path}")
- print(f"Config: {config_path}")
- print("="*60)
-
- # Load config
- with open(config_path, 'r') as f:
- config = yaml.safe_load(f)
-
- # Load dataset
- dataset = FlowFixDataset(
- data_dir=config['data']['data_dir'],
- split_file=config['data'].get('split_file'),
- split=split,
- max_samples=None,
- seed=42,
- fix_pose=config['data'].get('fix_pose', False),
- fix_pose_high_rmsd=config['data'].get('fix_pose_high_rmsd', False)
- )
-
- # Selection
- if pdb_id:
- try:
- target_idx = dataset.pdb_ids.index(pdb_id)
- sample = dataset[target_idx]
- except ValueError:
- print(f"PDB {pdb_id} not in val split, using first entry")
- sample = dataset[0]
- else:
- sample = dataset[0]
-
- # Model
- import inspect
- sig = inspect.signature(ProteinLigandFlowMatchingJoint.__init__)
- valid_params = sig.parameters.keys()
- model_config = {k: v for k, v in config['model'].items() if k in valid_params}
-
- model = ProteinLigandFlowMatchingJoint(**model_config).to(device)
- checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
- model.load_state_dict(checkpoint['model_state_dict'], strict=False)
-
- # TEST: Keep in train mode to see if BatchNorm is the culprit
- model.eval()
- model.eval()
- print("WARNING: Running in MODEL.TRAIN() mode to avoid BatchNorm eval mismatch")
-
- # Data
- batch = collate_flowfix_batch([sample])
- ligand_batch = batch['ligand_graph'].to(device)
- protein_batch = batch['protein_graph'].to(device)
- ligand_coords_x0 = batch['ligand_coords_x0'].to(device)
- ligand_coords_x1 = batch['ligand_coords_x1'].to(device)
-
- initial_rmsd = torch.sqrt(torch.mean((ligand_coords_x0 - ligand_coords_x1) ** 2)).item()
- print(f"\nPDB: {batch['pdb_ids'][0]}")
- print(f"Initial RMSD: {initial_rmsd:.4f} Å")
-
- # ODE Integration
- num_steps = 20
- current_coords = ligand_coords_x0.clone()
- timesteps = torch.linspace(0.0, 1.0, num_steps + 1)
-
- print(f"\n{'Step':>4} | {'t':>5} | {'RMSD':>8} | {'|v|':>8} | {'|v_tgt|':>8} | {'Sim':>8}")
- print("-" * 60)
-
- for step in range(num_steps):
- t_current = timesteps[step]
- t_next = timesteps[step+1]
- dt = t_next - t_current
-
- t = torch.ones(1, device=device) * t_current
-
- with torch.no_grad():
- ligand_batch.pos = current_coords
- velocity = model(protein_batch, ligand_batch, t)
-
- # target v in flow matching sense: (x1 - xt) / (1-t)
- eps = 1e-5
- v_target = (ligand_coords_x1 - current_coords) / (1.0 - t_current + eps)
-
- rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item()
- v_norm = torch.norm(velocity, dim=-1).mean().item()
- v_target_norm = torch.norm(v_target, dim=-1).mean().item()
-
- v_flat = velocity.view(-1)
- target_flat = v_target.view(-1)
- sim = torch.nn.functional.cosine_similarity(v_flat, target_flat, dim=0).item()
-
- if step % 5 == 0 or step == num_steps - 1:
- print(f"{step:4d} | {t_current:5.2f} | {rmsd:8.4f} | {v_norm:8.4f} | {v_target_norm:8.4f} | {sim:8.4f}")
-
- current_coords = current_coords + velocity * dt
-
- final_rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item()
- print("-" * 60)
- print(f"Final RMSD: {final_rmsd:.4f} Å")
- print(f"Refinement: {initial_rmsd - final_rmsd:.4f} Å")
-
- if do_relax:
- print("\n" + "="*60)
- print("Running Force Field Relaxation...")
- print("="*60)
-
- relax_engine = RelaxationEngine(
- clash_weight=1.0,
- dg_weight=1.0,
- restraint_weight=20.0,
- lr=1.0,
- max_steps=100
- )
-
- relaxed_coords, metrics = relax_engine.relax(
- ligand_coords=current_coords,
- protein_batch=protein_batch,
- distance_bounds=batch.get('distance_bounds'),
- device=device
- )
-
- relaxed_rmsd = torch.sqrt(torch.mean((relaxed_coords - ligand_coords_x1) ** 2)).item()
- print(f"Relaxed RMSD: {relaxed_rmsd:.4f} Å")
- print(f"Improvement: {final_rmsd - relaxed_rmsd:.4f} Å")
- print(f"Metrics: {metrics}")
- print("-" * 60)
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument('--ckpt', type=str, required=True)
- parser.add_argument('--config', type=str, required=True)
- parser.add_argument('--split', type=str, default='val')
- parser.add_argument('--pdb', type=str, default=None)
- parser.add_argument('--relax', action='store_true', help='Run force field relaxation')
- args = parser.parse_args()
-
- test_checkpoint(args.ckpt, args.config, args.split, args.pdb, args.relax)
diff --git a/scripts/debug_nan.py b/scripts/debug_nan.py
deleted file mode 100755
index 7cfbe96..0000000
--- a/scripts/debug_nan.py
+++ /dev/null
@@ -1,162 +0,0 @@
-#!/usr/bin/env python
-"""
-Debug script to identify NaN source in validation.
-Tests model forward pass step by step.
-"""
-
-import torch
-import yaml
-import sys
-sys.path.insert(0, '/home/jaemin/project/protein-ligand/pose-refine')
-
-from src.utils.model_builder import build_model
-from src.data.dataset import FlowFixDataset, collate_flowfix_batch
-from torch.utils.data import DataLoader
-
-def check_tensor(name, t):
- """Check tensor for NaN/Inf and print stats."""
- if t is None:
- print(f" {name}: None")
- return
- nan_count = torch.isnan(t).sum().item()
- inf_count = torch.isinf(t).sum().item()
- print(f" {name}: shape={tuple(t.shape)}, nan={nan_count}, inf={inf_count}, "
- f"min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
-
-def main():
- # Load config
- config_path = '/home/jaemin/project/protein-ligand/pose-refine/save/overfit-test-32/config.yaml'
- with open(config_path, 'r') as f:
- config = yaml.safe_load(f)
-
- device = torch.device('cuda:0')
-
- # Load dataset (validation mode - uses x0 positions, not x_t)
- val_dataset = FlowFixDataset(
- data_dir=config['data']['data_dir'],
- split_file=config['data']['split_file'],
- split='val',
- fix_pose=config['data'].get('fix_pose', True),
- fix_pose_high_rmsd=config['data'].get('fix_pose_high_rmsd', False),
- position_noise_scale=0.0, # No noise for validation
- loading_mode='lazy',
- max_samples=1,
- )
-
- val_loader = DataLoader(
- val_dataset,
- batch_size=1,
- shuffle=False,
- collate_fn=collate_flowfix_batch,
- )
-
- # Build model
- model = build_model(config['model'], device)
-
- # Load checkpoint
- checkpoint_path = '/home/jaemin/project/protein-ligand/pose-refine/save/overfit-test-32/checkpoints/latest.pt'
- checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
- model.load_state_dict(checkpoint['model_state_dict'])
-
- print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
-
- # Test both eval modes: standard eval and eval with BatchNorm in train mode
- for bn_mode in ['eval', 'train']:
- print(f"\n{'='*60}")
- print(f"TESTING WITH BATCHNORM IN {bn_mode.upper()} MODE")
- print(f"{'='*60}")
-
- model.eval()
-
- # Optionally keep BatchNorm in train mode
- if bn_mode == 'train':
- for module in model.modules():
- cls_name = type(module).__name__
- if 'BatchNorm' in cls_name:
- module.train()
-
- test_model(model, val_loader, device)
-
-
-def test_model(model, val_loader, device):
- """Test model forward pass and ODE integration with detailed metrics."""
-
- # Get one batch
- batch = next(iter(val_loader))
-
- # Move to device
- ligand_batch = batch['ligand_graph'].to(device)
- protein_batch = batch['protein_graph'].to(device)
- ligand_coords_x0 = batch['ligand_coords_x0'].to(device)
- ligand_coords_x1 = batch['ligand_coords_x1'].to(device)
-
- print(f"\n=== Input Data Check ===")
- print(f"PDB: {batch['pdb_ids'][0]}")
- check_tensor("ligand_coords_x0", ligand_coords_x0)
- check_tensor("ligand_coords_x1", ligand_coords_x1)
-
- # Initial RMSD
- initial_rmsd = torch.sqrt(torch.mean((ligand_coords_x0 - ligand_coords_x1) ** 2)).item()
- print(f"Initial RMSD: {initial_rmsd:.4f} Å")
-
- # Simulate ODE integration with detailed debugging
- print(f"\n=== Detailed ODE Integration Debug (Euler, 20 steps) ===")
- print(f"{'Step':>4} | {'t':>5} | {'RMSD':>7} | {'|v|':>7} | {'|v_tgt|':>7} | {'Sim':>7}")
- print("-" * 60)
-
- num_steps = 20
- current_coords = ligand_coords_x0.clone()
-
- for step in range(num_steps):
- t_current = step / num_steps
- t_next = (step + 1) / num_steps
- dt = t_next - t_current
-
- t = torch.tensor([t_current], device=device)
- ligand_batch.pos = current_coords.clone()
-
- # Self-conditioning (if enabled in model, usually x1_self_cond)
- # Note: In actual validation, we might need to handle self_conditioning properly
- # but for simple velocity direction check, t=0 is most critical.
-
- with torch.no_grad():
- # In validation, we might use x_t directly for self_cond
- velocity = model(protein_batch, ligand_batch, t)
-
- # Ideal velocity at this point to reach x1
- # In CFM with linear interpolation, the constant velocity is (x1 - x0)
- # However, our target direction at current point is (x1 - xt)
- v_target = ligand_coords_x1 - current_coords
-
- # Metrics
- rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item()
- v_norm = torch.norm(velocity, dim=-1).mean().item()
- v_target_norm = torch.norm(v_target, dim=-1).mean().item()
-
- # Cosine similarity between predicted velocity and target direction
- v_flat = velocity.view(-1)
- target_flat = v_target.view(-1)
- sim = torch.nn.functional.cosine_similarity(v_flat, target_flat, dim=0).item()
-
- print(f"{step:4d} | {t_current:5.2f} | {rmsd:7.4f} | {v_norm:7.4f} | {v_target_norm:7.4f} | {sim:7.4f}")
-
- # Clip velocity (same as in validation)
- velocity_norm = torch.norm(velocity, dim=-1, keepdim=True)
- max_velocity = 5.0
- scale = torch.clamp(max_velocity / (velocity_norm + 1e-8), max=1.0)
- velocity = velocity * scale
-
- # Update coords
- current_coords = current_coords + dt * velocity
-
- if torch.isnan(current_coords).any():
- print(f"!!! Error: NaN detected at step {step}")
- break
-
- final_rmsd = torch.sqrt(torch.mean((current_coords - ligand_coords_x1) ** 2)).item()
- print("-" * 45)
- print(f"Final RMSD: {final_rmsd:.4f} Å")
- print(f"Refinement: {initial_rmsd - final_rmsd:+.4f} Å")
-
-if __name__ == '__main__':
- main()
diff --git a/scripts/debug_velocity.py b/scripts/debug_velocity.py
deleted file mode 100755
index c3daf39..0000000
--- a/scripts/debug_velocity.py
+++ /dev/null
@@ -1,193 +0,0 @@
-#!/usr/bin/env python
-"""
-Debug script to check model velocity predictions.
-"""
-import os
-import sys
-import torch
-import yaml
-import numpy as np
-
-# Add project root to path
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-from src.data.dataset import FlowFixDataset, collate_flowfix_batch
-from src.utils.model_builder import build_model
-
-
-def main():
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- print(f"Device: {device}")
-
- # Load checkpoint from overfit-test-8
- checkpoint_path = 'save/overfit-test-8/checkpoints/latest.pt'
- if not os.path.exists(checkpoint_path):
- print(f"No checkpoint found at {checkpoint_path}")
- return
-
- print(f"Loading checkpoint: {checkpoint_path}")
- checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
-
- # Load config from checkpoint if available, otherwise from file
- if 'config' in checkpoint:
- config = checkpoint['config']
- print("Using config from checkpoint")
- else:
- config_path = 'configs/train_joint.yaml'
- with open(config_path) as f:
- config = yaml.safe_load(f)
- print(f"Using config from {config_path}")
-
- # Build model
- model = build_model(config['model'], device)
-
- # Load weights
- state_dict = checkpoint.get('model_state_dict', checkpoint)
- model.load_state_dict(state_dict, strict=False)
- model = model.to(device)
- model.eval()
-
- print(f"Model loaded, epoch {checkpoint.get('epoch', 'unknown')}")
-
- # Create dataset
- dataset = FlowFixDataset(
- data_dir='train_data',
- split_file='train_data/splits_overfit_tiny.json',
- split='train', # Use train split to get fixed pose
- seed=42,
- fix_pose=True, # Use pose 1 (easier)
- )
-
- # Get one sample
- sample = dataset[0]
- batch = collate_flowfix_batch([sample])
-
- pdb_id = batch['pdb_ids'][0]
- protein_batch = batch['protein_graph'].to(device)
- ligand_batch = batch['ligand_graph'].to(device)
- ligand_coords_x0 = batch['ligand_coords_x0'].to(device)
- ligand_coords_x1 = batch['ligand_coords_x1'].to(device)
- t = batch['t'].to(device)
-
- print(f"\nPDB: {pdb_id}")
- print(f"Num ligand atoms: {ligand_coords_x0.shape[0]}")
- print(f"t: {t.item():.4f}")
-
- # True velocity
- true_velocity = ligand_coords_x1 - ligand_coords_x0
- true_velocity_norm = torch.norm(true_velocity, dim=-1).mean()
-
- print(f"\n=== True Velocity ===")
- print(f" Mean norm: {true_velocity_norm.item():.4f} Å")
- print(f" First 5 atoms:\n{true_velocity[:5]}")
-
- # Predicted velocity at t=0
- with torch.no_grad():
- ligand_batch.pos = ligand_coords_x0.clone()
- t_zero = torch.zeros(1, device=device)
- pred_velocity_t0 = model(protein_batch, ligand_batch, t_zero)
-
- pred_velocity_t0_norm = torch.norm(pred_velocity_t0, dim=-1).mean()
-
- print(f"\n=== Predicted Velocity (t=0) ===")
- print(f" Mean norm: {pred_velocity_t0_norm.item():.4f} Å")
- print(f" First 5 atoms:\n{pred_velocity_t0[:5]}")
-
- # Cosine similarity
- cos_sim = torch.nn.functional.cosine_similarity(
- pred_velocity_t0.flatten(),
- true_velocity.flatten(),
- dim=0
- )
- print(f"\n=== Comparison ===")
- print(f" Cosine similarity: {cos_sim.item():.4f}")
- print(f" Pred/True norm ratio: {pred_velocity_t0_norm.item() / true_velocity_norm.item():.4f}")
-
- # Check if velocity is near zero
- if pred_velocity_t0_norm.item() < 0.1:
- print(f"\n⚠️ WARNING: Predicted velocity is near zero!")
-
- # Predicted velocity at various t
- print(f"\n=== Velocity at different timesteps ===")
- for t_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
- t_test = torch.tensor([t_val], device=device)
- x_t = (1 - t_val) * ligand_coords_x0 + t_val * ligand_coords_x1
- ligand_batch.pos = x_t.clone()
-
- with torch.no_grad():
- pred_v = model(protein_batch, ligand_batch, t_test)
-
- pred_norm = torch.norm(pred_v, dim=-1).mean()
- cos = torch.nn.functional.cosine_similarity(
- pred_v.flatten(),
- true_velocity.flatten(),
- dim=0
- )
- print(f" t={t_val:.2f}: ||v||={pred_norm.item():.4f}, cos_sim={cos.item():.4f}")
-
- # Initial and final RMSD
- initial_rmsd = torch.sqrt(torch.mean((ligand_coords_x0 - ligand_coords_x1)**2))
- print(f"\n=== RMSD ===")
- print(f" Initial (docked vs crystal): {initial_rmsd.item():.4f} Å")
-
- # Simulate one step
- dt = 0.05
- new_coords = ligand_coords_x0 + dt * pred_velocity_t0
- new_rmsd = torch.sqrt(torch.mean((new_coords - ligand_coords_x1)**2))
- print(f" After one step (dt={dt}): {new_rmsd.item():.4f} Å")
-
- # Check for NaN/Inf in velocity
- print(f"\n=== NaN/Inf Check ===")
- print(f" Velocity has NaN: {torch.isnan(pred_velocity_t0).any().item()}")
- print(f" Velocity has Inf: {torch.isinf(pred_velocity_t0).any().item()}")
- print(f" Velocity max: {pred_velocity_t0.abs().max().item():.4f}")
-
- # Full ODE simulation
- print(f"\n=== Full ODE Simulation (10 steps) ===")
- current = ligand_coords_x0.clone()
- for step in range(10):
- t_val = step / 10
- t_test = torch.tensor([t_val], device=device)
- ligand_batch.pos = current.clone()
- with torch.no_grad():
- v = model(protein_batch, ligand_batch, t_test)
- dt = 0.1
- current = current + dt * v
- rmsd = torch.sqrt(torch.mean((current - ligand_coords_x1)**2)).item()
- v_norm = torch.norm(v, dim=-1).mean().item()
- v_max = torch.abs(v).max().item()
- has_nan = torch.isnan(current).any().item()
- print(f" Step {step}: t={t_val:.2f}, RMSD={rmsd:.4f}A, ||v||={v_norm:.4f}, max|v|={v_max:.4f}, NaN={has_nan}")
- if has_nan:
- print(" NaN detected, stopping")
- break
-
- # Test at t=0 vs t=0.5 to see if model generalizes
- print(f"\n=== Model behavior at different timesteps (same position) ===")
- ligand_batch.pos = ligand_coords_x0.clone()
- for t_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
- t_test = torch.tensor([t_val], device=device)
- with torch.no_grad():
- v = model(protein_batch, ligand_batch, t_test)
- v_norm = torch.norm(v, dim=-1).mean().item()
- v_max = torch.abs(v).max().item()
- cos_sim = torch.nn.functional.cosine_similarity(v.flatten(), true_velocity.flatten(), dim=0).item()
- print(f" t={t_val:.2f}: ||v||={v_norm:.4f}, max|v|={v_max:.4f}, cos_sim={cos_sim:.4f}")
-
- # Test at EXACT training position x_t
- print(f"\n=== Test at EXACT training position x_t (should match training) ===")
- for t_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
- # Compute x_t = (1-t)*x0 + t*x1
- x_t = (1 - t_val) * ligand_coords_x0 + t_val * ligand_coords_x1
- ligand_batch.pos = x_t.clone()
- t_test = torch.tensor([t_val], device=device)
- with torch.no_grad():
- v = model(protein_batch, ligand_batch, t_test)
- v_norm = torch.norm(v, dim=-1).mean().item()
- cos_sim = torch.nn.functional.cosine_similarity(v.flatten(), true_velocity.flatten(), dim=0).item()
- mse = torch.mean((v - true_velocity)**2).item()
- print(f" t={t_val:.2f}: ||v||={v_norm:.4f}, cos_sim={cos_sim:.4f}, MSE={mse:.4f}")
-
-
-if __name__ == '__main__':
- main()
diff --git a/scripts/slurm/run_debug.sh b/scripts/slurm/run_debug.sh
deleted file mode 100755
index 47db03d..0000000
--- a/scripts/slurm/run_debug.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=debug-velocity
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=4
-#SBATCH --mem=32G
-#SBATCH --time=00:10:00
-#SBATCH --output=logs/debug_%j.out
-#SBATCH --error=logs/debug_%j.out
-#SBATCH --exclude=gpu3
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-PYTHONPATH=/home/jaemin/project/protein-ligand/pose-refine python scripts/debug_velocity.py
diff --git a/scripts/slurm/run_debug_ckpts.sh b/scripts/slurm/run_debug_ckpts.sh
deleted file mode 100755
index ae40600..0000000
--- a/scripts/slurm/run_debug_ckpts.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=debug-train
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=32G
-#SBATCH --time=00:30:00
-#SBATCH --output=logs/debug_train_%j.out
-#SBATCH --error=logs/debug_train_%j.out
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-export PYTHONPATH=$PYTHONPATH:.
-
-echo "============================================================"
-echo "Analyzing Job 203 (32-Sample Fixed) - Training Sample 10gs"
-echo "============================================================"
-python scripts/debug_ckpt.py --ckpt save/overfit-test-32-fixed/checkpoints/latest.pt --config configs/train_overfit_32.yaml --split train --pdb 10gs
-
-echo -e "\n\n============================================================"
-echo "Analyzing Job 226 (Rectified Flow) - Training Sample 10gs (Epoch 20)"
-echo "============================================================"
-python scripts/debug_ckpt.py --ckpt save/overfit-fast-sink/checkpoints/epoch_0020.pt --config configs/train_rectified_flow.yaml --split train --pdb 10gs
diff --git a/scripts/slurm/run_debug_detailed.sh b/scripts/slurm/run_debug_detailed.sh
deleted file mode 100755
index 1447d20..0000000
--- a/scripts/slurm/run_debug_detailed.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=debug-ode
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=32G
-#SBATCH --time=01:00:00
-#SBATCH --output=logs/debug_ode_%j.out
-#SBATCH --error=logs/debug_ode_%j.out
-
-# Detailed ODE analysis script
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-python scripts/debug_nan.py
diff --git a/scripts/slurm/run_debug_nan.sh b/scripts/slurm/run_debug_nan.sh
deleted file mode 100755
index 8eb6313..0000000
--- a/scripts/slurm/run_debug_nan.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=debug-nan
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=4
-#SBATCH --mem=32G
-#SBATCH --time=00:30:00
-#SBATCH --output=logs/debug_nan_%j.out
-#SBATCH --error=logs/debug_nan_%j.out
-#SBATCH --exclude=gpu3
-
-# Debug NaN issue in validation
-# Tests the BatchNorm fix by running validation on checkpoint
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Enable TF32
-export TORCH_ALLOW_TF32_CUBLAS=1
-
-# Run debug script
-python scripts/debug_nan.py
diff --git a/scripts/slurm/run_debug_relax.sh b/scripts/slurm/run_debug_relax.sh
deleted file mode 100755
index 193beda..0000000
--- a/scripts/slurm/run_debug_relax.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=debug-relax
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=4
-#SBATCH --mem=32G
-#SBATCH --time=00:30:00
-#SBATCH --output=logs/debug_relax_%j.out
-#SBATCH --error=logs/debug_relax_%j.out
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-PYTHONPATH=/home/jaemin/project/protein-ligand/pose-refine python scripts/debug_ckpt.py --ckpt save/overfit-fast-sink/checkpoints/latest.pt --config configs/train_rectified_flow.yaml --split train --pdb 10gs --relax
diff --git a/scripts/slurm/run_fast_overfit.sh b/scripts/slurm/run_fast_overfit.sh
deleted file mode 100755
index 546cac0..0000000
--- a/scripts/slurm/run_fast_overfit.sh
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=overfit-fast-sink
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=64G
-#SBATCH --time=04:00:00
-#SBATCH --output=logs/overfit_fast_sink_%j.out
-#SBATCH --error=logs/overfit_fast_sink_%j.out
-#SBATCH --exclude=gpu3
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Enable TF32
-export TORCH_ALLOW_TF32_CUBLAS=1
-
-# Run training
-python train.py --config configs/train_fast_overfit.yaml
diff --git a/scripts/slurm/run_full_val_inference.sh b/scripts/slurm/run_full_val_inference.sh
index 66789d2..deb6e37 100755
--- a/scripts/slurm/run_full_val_inference.sh
+++ b/scripts/slurm/run_full_val_inference.sh
@@ -42,7 +42,7 @@ fi
echo "Starting full validation inference (200 val PDBs, all poses)..."
$PYTHON -u scripts/analysis/infer_full_validation.py \
- --config configs/train_joint.yaml \
+ --config configs/train.yaml \
--checkpoint "$CHECKPOINT" \
--output "$OUTPUT"
diff --git a/scripts/slurm/run_overfit_32.sh b/scripts/slurm/run_overfit_32.sh
deleted file mode 100755
index 78b2b57..0000000
--- a/scripts/slurm/run_overfit_32.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=overfit-32
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=64G
-#SBATCH --time=04:00:00
-#SBATCH --output=logs/overfit_32_%j.out
-#SBATCH --error=logs/overfit_32_%j.out
-#SBATCH --exclude=gpu3
-
-# Overfit test with 32 samples
-# More data for stable BatchNorm stats, val=subset of train
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Enable TF32
-export TORCH_ALLOW_TF32_CUBLAS=1
-
-# Run training
-python train.py --config configs/train_overfit_32.yaml
diff --git a/scripts/slurm/run_overfit_32_fixed.sh b/scripts/slurm/run_overfit_32_fixed.sh
deleted file mode 100755
index 7c4c9ad..0000000
--- a/scripts/slurm/run_overfit_32_fixed.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=overfit-32-fix
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=64G
-#SBATCH --time=04:00:00
-#SBATCH --output=logs/overfit_32_fix_%j.out
-#SBATCH --error=logs/overfit_32_fix_%j.out
-#SBATCH --exclude=gpu3
-
-# Overfit test with 32 samples
-# More data for stable BatchNorm stats, val=subset of train
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Enable TF32
-export TORCH_ALLOW_TF32_CUBLAS=1
-
-# Run training
-python train.py --config configs/train_overfit_32.yaml
diff --git a/scripts/slurm/run_overfit_high_rmsd.sh b/scripts/slurm/run_overfit_high_rmsd.sh
deleted file mode 100755
index a1f0ef0..0000000
--- a/scripts/slurm/run_overfit_high_rmsd.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=overfit-high-rmsd
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=64G
-#SBATCH --time=04:00:00
-#SBATCH --output=logs/overfit_high_rmsd_%j.out
-#SBATCH --error=logs/overfit_high_rmsd_%j.out
-#SBATCH --exclude=gpu3
-
-# Overfit test with high RMSD poses (32 samples)
-# Testing if the model can handle large refinements
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Enable TF32
-export TORCH_ALLOW_TF32_CUBLAS=1
-
-# Run training
-python train.py --config configs/train_overfit_high_rmsd.yaml
diff --git a/scripts/slurm/run_overfit_quick.sh b/scripts/slurm/run_overfit_quick.sh
deleted file mode 100755
index 9015f4a..0000000
--- a/scripts/slurm/run_overfit_quick.sh
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=overfit-quick
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=64G
-#SBATCH --time=01:00:00
-#SBATCH --output=logs/overfit_quick_%j.out
-#SBATCH --error=logs/overfit_quick_%j.out
-#SBATCH --exclude=gpu3
-
-# Quick overfit test to verify BatchNorm fix
-# Runs shorter training with more frequent validation to test NaN fix
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Enable TF32
-export TORCH_ALLOW_TF32_CUBLAS=1
-
-# Run with quick validation (validate every 50 epochs instead of 100)
-python train.py \
- --config configs/train_overfit_test.yaml \
- --name overfit-quick-batchnorm-fix \
- --training.num_epochs 200 \
- --training.validation.frequency 50
diff --git a/scripts/slurm/run_overfit_test.sh b/scripts/slurm/run_overfit_test.sh
deleted file mode 100755
index ddf78f9..0000000
--- a/scripts/slurm/run_overfit_test.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=flowfix-overfit
-#SBATCH --partition=6000ada
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=64G
-#SBATCH --time=04:00:00
-#SBATCH --output=logs/slurm_%j.out
-#SBATCH --error=logs/slurm_%j.out
-#SBATCH --exclude=gpu3
-
-# Overfit test: single GPU, small dataset
-
-source /home/jaemin/miniforge3/etc/profile.d/conda.sh
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Enable TF32
-export TORCH_ALLOW_TF32_CUBLAS=1
-
-# Single GPU training
-python train.py --config configs/train_overfit_test.yaml
diff --git a/scripts/slurm/run_test_quick.sh b/scripts/slurm/run_test_quick.sh
deleted file mode 100755
index 942126d..0000000
--- a/scripts/slurm/run_test_quick.sh
+++ /dev/null
@@ -1,106 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=flowfix_test
-#SBATCH --partition=6000ada
-#SBATCH --gres=gpu:1
-#SBATCH --cpus-per-task=4
-#SBATCH --mem=32G
-#SBATCH --time=00:30:00
-#SBATCH --output=logs/slurm_test_%j.out
-#SBATCH --error=logs/slurm_test_%j.out
-
-source ~/.bashrc
-conda activate torch-2.8
-
-cd /home/jaemin/project/protein-ligand/pose-refine
-
-# Quick test with small data
-python -c "
-import torch
-from src.data.dataset import FlowFixDataset, collate_flowfix_batch
-from src.models.flowmatching import ProteinLigandFlowMatchingJoint
-from src.utils.model_builder import build_model
-from torch.utils.data import DataLoader
-import yaml
-
-print('='*60)
-print('Testing Dataset + Model Integration')
-print('='*60)
-
-# Load config
-with open('configs/train_joint.yaml') as f:
- config = yaml.safe_load(f)
-
-# Create small dataset
-print('\\n1. Creating dataset...')
-ds = FlowFixDataset(
- data_dir='train_data',
- split='train',
- max_samples=4,
- cross_edge_cutoff=6.0,
- cross_edge_max_neighbors=16,
- intra_edge_cutoff=6.0,
- intra_edge_max_neighbors=16,
-)
-ds.set_epoch(0)
-print(f' Dataset size: {len(ds)}')
-
-# Create dataloader
-print('\\n2. Creating dataloader...')
-loader = DataLoader(ds, batch_size=2, collate_fn=collate_flowfix_batch, num_workers=0)
-batch = next(iter(loader))
-print(f' Batch keys: {list(batch.keys())}')
-print(f' t shape: {batch[\"t\"].shape}')
-print(f' cross_edge_index shape: {batch[\"cross_edge_index\"].shape}')
-print(f' intra_edge_index shape: {batch[\"intra_edge_index\"].shape}')
-
-# Create model
-print('\\n3. Building model...')
-model = build_model(config['model']).cuda()
-print(f' Model type: {type(model).__name__}')
-print(f' Parameters: {sum(p.numel() for p in model.parameters()):,}')
-
-# Move batch to GPU
-print('\\n4. Moving batch to GPU...')
-protein_batch = batch['protein_graph'].to('cuda')
-ligand_batch = batch['ligand_graph'].to('cuda')
-t = batch['t'].to('cuda')
-cross_edge_index = batch['cross_edge_index'].to('cuda')
-intra_edge_index = batch['intra_edge_index'].to('cuda')
-ligand_coords_x0 = batch['ligand_coords_x0'].to('cuda')
-ligand_coords_x1 = batch['ligand_coords_x1'].to('cuda')
-
-# Forward pass
-print('\\n5. Forward pass...')
-velocity = model(protein_batch, ligand_batch, t, cross_edge_index, intra_edge_index)
-print(f' Velocity shape: {velocity.shape}')
-print(f' Expected: {ligand_batch.pos.shape}')
-
-# Self-conditioning test
-print('\\n6. Self-conditioning test...')
-t_expanded = t[ligand_batch.batch].unsqueeze(-1)
-x1_self_cond = ligand_batch.pos + (1 - t_expanded) * velocity
-velocity2 = model(protein_batch, ligand_batch, t, cross_edge_index, intra_edge_index, x1_self_cond=x1_self_cond)
-print(f' Velocity2 shape: {velocity2.shape}')
-
-# Loss computation
-print('\\n7. Loss computation...')
-true_velocity = ligand_coords_x1 - ligand_coords_x0
-loss = torch.mean((velocity - true_velocity) ** 2)
-print(f' Loss: {loss.item():.6f}')
-
-# Backward pass
-print('\\n8. Backward pass...')
-loss.backward()
-print(' Backward: OK')
-
-# Validation mode (no edges passed - should use fallback)
-print('\\n9. Validation mode (edge fallback)...')
-model.eval()
-with torch.no_grad():
- velocity_val = model(protein_batch, ligand_batch, t) # No edges - uses fallback
-print(f' Velocity shape: {velocity_val.shape}')
-
-print('\\n' + '='*60)
-print('All tests PASSED!')
-print('='*60)
-"
diff --git a/scripts/slurm/run_train_joint.sh b/scripts/slurm/run_train_joint.sh
deleted file mode 100755
index fb6905c..0000000
--- a/scripts/slurm/run_train_joint.sh
+++ /dev/null
@@ -1,95 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=flowfix_joint
-#SBATCH --output=logs/slurm_%j.out
-#SBATCH --partition=6000ada # GPU partition
-#SBATCH --nodes=1 # Single node
-#SBATCH --ntasks-per-node=1 # One task per node
-#SBATCH --cpus-per-task=32 # CPUs for data loading
-#SBATCH --gres=gpu:8 # Request 8 GPUs
-#SBATCH --mem=0 # Request all available memory
-#SBATCH --time=30-00:00:00 # 30 days time limit
-#SBATCH --exclude=gpu3 # Exclude gpu3 (driver mismatch)
-
-# Python path
-PYTHON=/home/jaemin/miniforge3/envs/torch-2.8/bin/python
-
-# Project directory
-PROJECT_DIR=/home/jaemin/project/protein-ligand/pose-refine
-
-# Resume checkpoint (set this to resume training, leave empty for fresh start)
-CHECKPOINT_PATH=""
-
-# Print job info
-echo "=========================================="
-echo "SLURM Job ID: $SLURM_JOB_ID"
-echo "Node: $SLURM_NODELIST"
-echo "Start time: $(date)"
-echo "Working directory: $PROJECT_DIR"
-echo "Config: configs/train_joint.yaml"
-echo "Number of GPUs: $SLURM_GPUS_ON_NODE"
-echo "=========================================="
-
-# Print GPU info
-nvidia-smi
-
-# Check Python and CUDA
-echo ""
-echo "Python: $PYTHON"
-echo "Python version: $($PYTHON --version)"
-echo "PyTorch version: $($PYTHON -c 'import torch; print(torch.__version__)')"
-echo "CUDA available: $($PYTHON -c 'import torch; print(torch.cuda.is_available())')"
-echo "CUDA version: $($PYTHON -c 'import torch; print(torch.version.cuda)')"
-echo "Number of CUDA devices: $($PYTHON -c 'import torch; print(torch.cuda.device_count())')"
-echo ""
-
-# Create logs directory if it doesn't exist
-mkdir -p "$PROJECT_DIR/logs"
-
-# Change to project directory
-cd "$PROJECT_DIR"
-
-# Disable Python output buffering for real-time logs
-export PYTHONUNBUFFERED=1
-
-# Set NCCL environment variables
-export NCCL_DEBUG=WARN
-export NCCL_P2P_DISABLE=1
-export NCCL_IB_DISABLE=1
-
-# Set number of GPUs
-if [ -n "$SLURM_GPUS_ON_NODE" ]; then
- NUM_GPUS=$SLURM_GPUS_ON_NODE
-elif [ -n "$SLURM_GPUS_PER_NODE" ]; then
- NUM_GPUS=$SLURM_GPUS_PER_NODE
-elif [ -n "$SLURM_JOB_GPUS" ]; then
- NUM_GPUS=$(echo $SLURM_JOB_GPUS | awk -F',' '{print NF}')
-else
- NUM_GPUS=8
- echo "Warning: SLURM GPU variables not found, using default NUM_GPUS=8"
-fi
-
-# Build resume argument if checkpoint path is set
-RESUME_ARG=""
-if [ -n "$CHECKPOINT_PATH" ]; then
- RESUME_ARG="--resume $CHECKPOINT_PATH"
- echo "Resuming from checkpoint: $CHECKPOINT_PATH"
-fi
-
-# Run distributed training using torch.distributed.run
-echo "Starting multi-GPU training with $NUM_GPUS GPUs..."
-echo "Command: python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS train.py --config configs/train_joint.yaml $RESUME_ARG"
-echo ""
-
-$PYTHON -m torch.distributed.run \
- --standalone \
- --nnodes=1 \
- --nproc_per_node=$NUM_GPUS \
- train.py \
- --config configs/train_joint.yaml \
- $RESUME_ARG
-
-# Print end time
-echo ""
-echo "=========================================="
-echo "End time: $(date)"
-echo "=========================================="
diff --git a/scripts/slurm/run_train_joint_test.sh b/scripts/slurm/run_train_joint_test.sh
deleted file mode 100755
index 952272a..0000000
--- a/scripts/slurm/run_train_joint_test.sh
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=flowfix_test
-#SBATCH --output=logs/slurm_%j.out
-#SBATCH --partition=test # Test partition (gpu2, 2h limit)
-#SBATCH --nodes=1
-#SBATCH --ntasks-per-node=1
-#SBATCH --cpus-per-task=16
-#SBATCH --gres=gpu:a5000:1 # 1 A5000 GPU
-#SBATCH --mem=32G
-
-# Python path
-PYTHON=/home/jaemin/miniforge3/envs/torch-2.8/bin/python
-
-# Project directory
-PROJECT_DIR=/home/jaemin/project/protein-ligand/pose-refine
-
-# Print job info
-echo "=========================================="
-echo "SLURM Job ID: $SLURM_JOB_ID"
-echo "Node: $SLURM_NODELIST"
-echo "Start time: $(date)"
-echo "Config: configs/train_joint_test.yaml"
-echo "=========================================="
-
-nvidia-smi
-
-echo ""
-echo "Python: $PYTHON"
-$PYTHON --version
-$PYTHON -c "import torch; print(f'PyTorch {torch.__version__}, CUDA avail: {torch.cuda.is_available()}, devices: {torch.cuda.device_count()}')"
-echo ""
-
-mkdir -p "$PROJECT_DIR/logs"
-cd "$PROJECT_DIR"
-
-export PYTHONUNBUFFERED=1
-
-# Single-GPU test run (no distributed)
-echo "Starting single-GPU test run..."
-$PYTHON train.py --config configs/train_joint_test.yaml
-
-echo ""
-echo "=========================================="
-echo "End time: $(date)"
-echo "=========================================="
diff --git a/scripts/slurm/run_visualize_trajectory.sh b/scripts/slurm/run_visualize_trajectory.sh
deleted file mode 100755
index 676ad80..0000000
--- a/scripts/slurm/run_visualize_trajectory.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=flowfix_viz
-#SBATCH --output=logs/viz_%j.out
-#SBATCH --error=logs/viz_%j.err
-#SBATCH --partition=g4090_short,6000ada_short,h100_short
-#SBATCH --nodes=1
-#SBATCH --ntasks=1
-#SBATCH --cpus-per-task=8
-#SBATCH --gres=gpu:1
-
-echo "=============================================="
-echo "FlowFix Trajectory Visualization"
-echo "=============================================="
-echo "Job ID: $SLURM_JOB_ID"
-echo "Node: $SLURM_NODELIST"
-echo "Start time: $(date)"
-echo "=============================================="
-
-cd /home/sim/project/flowfix
-mkdir -p logs
-
-PYTHON=/home/sim/miniconda3/envs/protein-ligand/bin/python
-
-echo ""
-echo "Running trajectory visualization..."
-$PYTHON -u scripts/analysis/visualize_trajectory.py
-
-echo ""
-echo "=============================================="
-echo "Visualization completed!"
-echo "End time: $(date)"
-echo "=============================================="
diff --git a/src/data/dataset_torsion.py b/src/data/dataset_torsion.py
new file mode 100644
index 0000000..9a55a1a
--- /dev/null
+++ b/src/data/dataset_torsion.py
@@ -0,0 +1,206 @@
+"""
+Torsion-aware dataset and collation for SE(3) + Torsion decomposition training.
+
+FlowFixTorsionDataset extends FlowFixDataset to compute and return
+torsion decomposition data (translation, rotation, torsion_changes, mask_rotate).
+"""
+
+import torch
+import numpy as np
+from typing import Optional, List, Dict, Any
+
+from .dataset import FlowFixDataset, collate_flowfix_batch
+
+
+class FlowFixTorsionDataset(FlowFixDataset):
+ """
+ FlowFixDataset with SE(3) + Torsion decomposition.
+
+ Returns additional torsion_data dict with:
+ - translation [3], rotation [3]
+ - torsion_changes [M], rotatable_edges [M, 2], mask_rotate [M, N]
+ """
+
+ def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
+ sample = super().__getitem__(idx)
+ if sample is None:
+ return None
+
+ # Compute torsion decomposition
+ pdb_id = self.pdb_ids[idx]
+ pdb_dir = self.data_dir / pdb_id
+
+ # Reload ligand_data to access edges (parent only returns coords)
+ if self.loading_mode == "preload":
+ ligands_list = self.preloaded_data[pdb_id]['ligands']
+ elif self.loading_mode == "hybrid":
+ ligands_list = torch.load(pdb_dir / "ligands.pt", weights_only=False)
+ else:
+ ligands_list = torch.load(pdb_dir / "ligands.pt", weights_only=False)
+
+ # Use same pose index as parent (deterministic via epoch + idx)
+ rng = np.random.RandomState(self.seed + self.epoch * 10000 + idx)
+ pose_idx = rng.randint(0, len(ligands_list))
+ ligand_data = ligands_list[pose_idx]
+
+ torsion_data = _compute_torsion_data(
+ ligand_data,
+ sample['ligand_coords_x0'],
+ sample['ligand_coords_x1'],
+ )
+ sample['torsion_data'] = torsion_data
+ return sample
+
+
+def _compute_torsion_data(
+ ligand_data: dict,
+ coords_x0: torch.Tensor,
+ coords_x1: torch.Tensor,
+) -> Optional[Dict[str, torch.Tensor]]:
+ """
+ Compute SE(3) + Torsion decomposition from ligand data.
+
+ Returns dict with translation, rotation, torsion_changes, rotatable_edges, mask_rotate.
+ """
+ # Use pre-computed if available
+ if 'torsion_changes' in ligand_data and 'mask_rotate' in ligand_data:
+ return {
+ 'translation': ligand_data.get('translation', torch.zeros(3)),
+ 'rotation': ligand_data.get('rotation', torch.zeros(3)),
+ 'torsion_changes': ligand_data['torsion_changes'],
+ 'rotatable_edges': ligand_data.get('rotatable_edges', torch.zeros(0, 2, dtype=torch.long)),
+ 'mask_rotate': ligand_data['mask_rotate'],
+ }
+
+ # Compute on-the-fly
+ try:
+ from src.data.ligand_feat import compute_rigid_transform, get_transformation_mask
+
+ translation, rotation = compute_rigid_transform(coords_x0, coords_x1)
+
+ # Get edges
+ edges = None
+ if 'edges' in ligand_data:
+ edges = ligand_data['edges']
+ elif 'edge' in ligand_data and 'edges' in ligand_data['edge']:
+ edges = ligand_data['edge']['edges']
+
+ n_atoms = coords_x0.shape[0]
+ empty_result = {
+ 'translation': translation,
+ 'rotation': rotation,
+ 'torsion_changes': torch.zeros(0),
+ 'rotatable_edges': torch.zeros(0, 2, dtype=torch.long),
+ 'mask_rotate': torch.zeros(0, n_atoms, dtype=torch.bool),
+ }
+
+ if edges is None:
+ return empty_result
+
+ mask_rotate, rotatable_edge_indices = get_transformation_mask(edges, n_atoms)
+
+ if len(rotatable_edge_indices) == 0:
+ return empty_result
+
+ edges_np = edges.numpy() if torch.is_tensor(edges) else edges
+ rot_edges = torch.tensor(
+ [[int(edges_np[0, i]), int(edges_np[1, i])] for i in rotatable_edge_indices],
+ dtype=torch.long,
+ )
+ mask_rot_filtered = mask_rotate[rotatable_edge_indices]
+
+ return {
+ 'translation': translation,
+ 'rotation': rotation,
+ 'torsion_changes': torch.zeros(len(rotatable_edge_indices)),
+ 'rotatable_edges': rot_edges,
+ 'mask_rotate': mask_rot_filtered,
+ }
+
+ except Exception:
+ return None
+
+
+def collate_torsion_batch(samples: List[Dict]) -> Dict[str, Any]:
+ """
+ Collate batch with torsion data.
+
+ Wraps collate_flowfix_batch and adds torsion_data collation.
+ """
+ # Filter None
+ samples = [s for s in samples if s is not None]
+ if not samples:
+ raise ValueError("All samples in batch are None!")
+
+ # Base collation (protein, ligand, coords, distance bounds)
+ batch = collate_flowfix_batch(samples)
+
+ # Collate torsion data
+ torsion_list = [s.get('torsion_data', None) for s in samples]
+ batch['torsion_data'] = _collate_torsion_data(torsion_list, batch['ligand_graph'])
+
+ return batch
+
+
+def _collate_torsion_data(
+ torsion_data_list: List[Optional[Dict[str, torch.Tensor]]],
+ ligand_batch,
+) -> Optional[Dict[str, torch.Tensor]]:
+ """
+ Collate torsion data across batch samples.
+
+ Concatenates variable-length rotatable bonds, adjusts atom indices
+ with batch offsets, and expands mask_rotate to full batch size.
+ """
+ valid = [td for td in torsion_data_list if td is not None]
+ if not valid:
+ return None
+
+ total_atoms = ligand_batch.num_nodes
+ atom_counts = torch.bincount(ligand_batch.batch)
+ atom_offsets = torch.cat([torch.zeros(1, dtype=torch.long), atom_counts.cumsum(0)[:-1]])
+
+ translations = []
+ rotations = []
+ torsion_changes = []
+ rotatable_edges = []
+ mask_rotate_list = []
+
+ for i, td in enumerate(torsion_data_list):
+ if td is None:
+ translations.append(torch.zeros(3))
+ rotations.append(torch.zeros(3))
+ continue
+
+ translations.append(td['translation'])
+ rotations.append(td['rotation'])
+
+ if td['torsion_changes'].numel() > 0:
+ torsion_changes.append(td['torsion_changes'])
+
+ offset = atom_offsets[i].item()
+ edges = td['rotatable_edges'].clone() + offset
+ rotatable_edges.append(edges)
+
+ # Expand mask to full batch atom count
+ m_i = td['mask_rotate'].shape[0]
+ n_i = td['mask_rotate'].shape[1]
+ full_mask = torch.zeros(m_i, total_atoms, dtype=torch.bool)
+ full_mask[:, offset:offset + n_i] = td['mask_rotate']
+ mask_rotate_list.append(full_mask)
+
+ result = {
+ 'translation': torch.stack(translations),
+ 'rotation': torch.stack(rotations),
+ }
+
+ if torsion_changes:
+ result['torsion_changes'] = torch.cat(torsion_changes)
+ result['rotatable_edges'] = torch.cat(rotatable_edges)
+ result['mask_rotate'] = torch.cat(mask_rotate_list)
+ else:
+ result['torsion_changes'] = torch.zeros(0)
+ result['rotatable_edges'] = torch.zeros(0, 2, dtype=torch.long)
+ result['mask_rotate'] = torch.zeros(0, total_atoms, dtype=torch.bool)
+
+ return result
diff --git a/src/models/flowmatching_torsion.py b/src/models/flowmatching_torsion.py
new file mode 100644
index 0000000..c8d1e6e
--- /dev/null
+++ b/src/models/flowmatching_torsion.py
@@ -0,0 +1,182 @@
+"""
+SE(3) + Torsion Decomposition Flow Matching Model.
+
+Instead of predicting per-atom Cartesian velocity [N, 3], this model predicts:
+- Translation [B, 3]: CoM displacement
+- Rotation [B, 3]: Axis-angle rotation around CoM
+- Torsion [M]: One scalar per rotatable bond
+
+Inherits the encoder and interaction network from ProteinLigandFlowMatching,
+replacing only the output heads.
+"""
+
+import torch
+import torch.nn as nn
+import cuequivariance as cue_base
+from torch_scatter import scatter_mean
+
+from .cue_layers import EquivariantMLP
+from .torch_layers import MLP
+from .flowmatching import ProteinLigandFlowMatching
+
+
+class ProteinLigandFlowMatchingTorsion(ProteinLigandFlowMatching):
+ """
+ SE(3) + Torsion decomposition variant of ProteinLigandFlowMatching.
+
+ Shares encoder, interaction network, and velocity blocks with the base class.
+ Replaces the output heads with:
+ - Translation head: mean-pool → EquivariantMLP → [B, 3]
+ - Rotation head: mean-pool → EquivariantMLP → [B, 3] (axis-angle)
+ - Torsion head: src/dst node scalar concat → MLP → [M, 1]
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+        # Read head dimensions from constructor kwargs (defaults mirror the parent's)
+ vel_hidden_scalar_dim = kwargs.get('velocity_hidden_scalar_dim', 128)
+ vel_hidden_vector_dim = kwargs.get('velocity_hidden_vector_dim', 16)
+ dropout = kwargs.get('dropout', 0.1)
+
+ self._vel_hidden_scalar_dim = vel_hidden_scalar_dim
+
+ vel_hidden_irreps = cue_base.Irreps(
+ "O3",
+ f"{vel_hidden_scalar_dim}x0e + {vel_hidden_vector_dim}x1o + {vel_hidden_vector_dim}x1e"
+ )
+ intermediate_irreps = cue_base.Irreps(
+ "O3",
+ f"{vel_hidden_scalar_dim}x0e + {vel_hidden_vector_dim}x1o + {vel_hidden_vector_dim}x1e"
+ )
+ vector_output_irreps = cue_base.Irreps("O3", "1x1o")
+
+ # Translation head: pooled features → 3D vector
+ self.translation_output = EquivariantMLP(
+ irreps_in=vel_hidden_irreps,
+ irreps_hidden=intermediate_irreps,
+ irreps_out=vector_output_irreps,
+ num_layers=2,
+ dropout=dropout,
+ )
+
+ # Rotation head: pooled features → 3D axis-angle
+ self.rotation_output = EquivariantMLP(
+ irreps_in=vel_hidden_irreps,
+ irreps_hidden=intermediate_irreps,
+ irreps_out=vector_output_irreps,
+ num_layers=2,
+ dropout=dropout,
+ )
+
+ # Torsion head: src/dst node scalars → 1 scalar per rotatable bond
+ torsion_input_dim = vel_hidden_scalar_dim * 2
+ self.torsion_output = MLP(
+ in_dim=torsion_input_dim,
+ hidden_dim=vel_hidden_scalar_dim,
+ out_dim=1,
+ num_layers=2,
+ activation='silu',
+ )
+
+        # Zero-initialize the torsion output head for stable training
+ self._zero_init_output_heads()
+
+ # Learnable scales
+ self.translation_scale = nn.Parameter(torch.ones(1) * 0.1)
+ self.rotation_scale = nn.Parameter(torch.ones(1) * 0.1)
+ self.torsion_scale = nn.Parameter(torch.ones(1) * 0.1)
+
+ def _zero_init_output_heads(self):
+ """Zero-initialize output heads for stable training start."""
+ for head in [self.torsion_output]:
+ for module in head.layers:
+ if isinstance(module, nn.Linear):
+ nn.init.zeros_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ def _encode(self, protein_batch, ligand_batch, t):
+ """
+ Run shared encoder + interaction + velocity blocks.
+
+ Returns h: [N_ligand, hidden_irreps] features after message passing.
+ """
+ # ESM embeddings
+ if self.use_esm_embeddings:
+ protein_batch = self._integrate_esm_embeddings(protein_batch)
+
+ protein_output = self.protein_network(protein_batch)
+ ligand_output = self.ligand_network(ligand_batch)
+
+ # Interaction
+ (_, lig_out), (protein_context, _), _ = self.interaction_network(
+ protein_output, ligand_output, protein_batch, ligand_batch
+ )
+
+ # Conditioning
+ protein_context_expanded = protein_context[ligand_batch.batch]
+ combined_condition = torch.cat([protein_context_expanded, lig_out], dim=-1)
+ atom_condition = self.vel_atom_condition_proj(combined_condition)
+
+ # Velocity blocks (shared backbone)
+ h = self.vel_input_projection(ligand_output)
+ h_initial = h
+
+ for block in self.velocity_blocks:
+ h = block(
+ h,
+ ligand_batch.pos,
+ ligand_batch.edge_index,
+ ligand_batch.edge_attr,
+ condition=atom_condition,
+ )
+
+ h = h + h_initial
+ return h
+
+ def forward(
+ self,
+ protein_batch,
+ ligand_batch,
+ t: torch.Tensor,
+ rotatable_edges: torch.Tensor = None,
+ ) -> dict:
+ """
+ Predict SE(3) + Torsion velocity.
+
+ Args:
+ protein_batch: Protein PyG batch
+ ligand_batch: Ligand PyG batch at time t
+ t: [B] timesteps
+ rotatable_edges: [M, 2] atom indices of rotatable bonds
+
+ Returns:
+ Dict with 'translation' [B, 3], 'rotation' [B, 3], 'torsion' [M]
+ """
+ h = self._encode(protein_batch, ligand_batch, t)
+
+ # Translation: mean-pool → 3D
+ h_pooled = scatter_mean(h, ligand_batch.batch, dim=0)
+ translation = self.translation_output(h_pooled) * self.translation_scale
+
+ # Rotation: mean-pool → 3D axis-angle
+ rotation = self.rotation_output(h_pooled) * self.rotation_scale
+
+ # Torsion: src/dst scalar features → 1 scalar per bond
+ if rotatable_edges is not None and rotatable_edges.shape[0] > 0:
+ src_idx = rotatable_edges[:, 0]
+ dst_idx = rotatable_edges[:, 1]
+
+ # Extract scalar part (first scalar_dim components of irreps)
+ h_scalar = h[:, :self._vel_hidden_scalar_dim]
+ edge_feat = torch.cat([h_scalar[src_idx], h_scalar[dst_idx]], dim=-1)
+ torsion = self.torsion_output(edge_feat).squeeze(-1) * self.torsion_scale
+ else:
+ torsion = torch.zeros(0, device=h.device)
+
+ return {
+ 'translation': translation,
+ 'rotation': rotation,
+ 'torsion': torsion,
+ }
diff --git a/src/utils/loss_utils.py b/src/utils/loss_utils.py
deleted file mode 100644
index a24d98a..0000000
--- a/src/utils/loss_utils.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""
-Simplified loss utilities for FlowFix training.
-Only velocity matching loss with logistic-normal time sampling.
-"""
-
-import torch
-import torch.distributions as dist
-
-
-def compute_flow_matching_loss(pred_velocity, true_velocity, batch_indices):
- """
- Simple uniform flow matching loss: pure MSE on velocity field.
-
- For linear interpolation, true velocity is constant across all t,
- so all timesteps should be equally weighted.
-
- Args:
- pred_velocity: Predicted velocity, shape [N, 3]
- true_velocity: True velocity (x1 - x0), shape [N, 3]
- batch_indices: Which batch each atom belongs to, shape [N] (unused but kept for compatibility)
-
- Returns:
- loss: Uniform MSE loss
- loss_dict: Dictionary with loss components
- """
- # Simple uniform MSE loss (no per-sample normalization, no weighting)
- total_loss = torch.nn.functional.mse_loss(pred_velocity, true_velocity)
-
- loss_dict = {
- 'velocity_loss': total_loss.item(),
- 'total_loss': total_loss.item()
- }
-
- return total_loss, loss_dict
-
-
-def sample_timesteps_logistic_normal(batch_size, device, mu=0.0, sigma=1.0, mix_ratio=0.5):
- """
- Sample timesteps using logistic-normal distribution.
-
- Mixture of:
- - Logistic-normal: sigma(Normal(mu, sigma)) where sigma is sigmoid
- - Uniform: for diversity
-
- Args:
- batch_size: Number of timesteps to sample
- device: torch device
- mu: Mean of underlying normal (logit space)
- sigma: Std of underlying normal (logit space)
- mix_ratio: Fraction to sample from logistic-normal (rest is uniform)
-
- Returns:
- t: Sampled timesteps in [0, 1], shape [batch_size]
- """
- n_logistic = int(batch_size * mix_ratio)
- n_uniform = batch_size - n_logistic
-
- # Logistic-normal samples
- normal_dist = dist.Normal(mu, sigma)
- logit_samples = normal_dist.sample((n_logistic,)).to(device)
- logistic_samples = torch.sigmoid(logit_samples)
-
- # Uniform samples
- uniform_samples = torch.rand(n_uniform, device=device)
-
- # Concatenate and shuffle
- t = torch.cat([logistic_samples, uniform_samples], dim=0)
- t = t[torch.randperm(len(t), device=device)]
-
- return t
diff --git a/src/utils/losses_torsion.py b/src/utils/losses_torsion.py
new file mode 100644
index 0000000..5870a34
--- /dev/null
+++ b/src/utils/losses_torsion.py
@@ -0,0 +1,169 @@
+"""
+Loss functions for SE(3) + Torsion decomposition training.
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+def compute_se3_torsion_loss(
+ pred: dict,
+ target: dict,
+ coords_x0: torch.Tensor,
+ coords_x1: torch.Tensor,
+ mask_rotate: torch.Tensor,
+ rotatable_edges: torch.Tensor,
+ batch_indices: torch.Tensor,
+ w_trans: float = 1.0,
+ w_rot: float = 1.0,
+ w_tor: float = 1.0,
+ w_coord: float = 0.5,
+) -> dict:
+ """
+ Compute SE(3) + Torsion decomposition loss.
+
+ Args:
+ pred: Model output with 'translation' [B,3], 'rotation' [B,3], 'torsion' [M]
+ target: Target with 'translation' [B,3], 'rotation' [B,3], 'torsion_changes' [M]
+ coords_x0: Docked coordinates [N, 3]
+ coords_x1: Crystal coordinates [N, 3]
+ mask_rotate: [M, N] boolean mask
+ rotatable_edges: [M, 2] atom indices
+ batch_indices: [N] batch assignment
+ w_trans, w_rot, w_tor, w_coord: Component weights
+
+ Returns:
+ Dict with 'total', 'translation', 'rotation', 'torsion', 'coord_recon' losses
+ """
+ device = pred['translation'].device
+
+ # Translation MSE
+ loss_trans = F.mse_loss(pred['translation'], target['translation'].to(device))
+
+ # Rotation MSE (axis-angle)
+ loss_rot = F.mse_loss(pred['rotation'], target['rotation'].to(device))
+
+ # Torsion circular MSE
+ if pred['torsion'].numel() > 0 and target['torsion_changes'].numel() > 0:
+ target_tor = target['torsion_changes'].to(device)
+ diff = pred['torsion'] - target_tor
+ diff = torch.atan2(torch.sin(diff), torch.cos(diff)) # wrap to [-pi, pi]
+ loss_tor = (diff ** 2).mean()
+ else:
+ loss_tor = torch.zeros(1, device=device).squeeze()
+
+ # Coordinate reconstruction loss
+ loss_coord = torch.zeros(1, device=device).squeeze()
+ if w_coord > 0:
+ reconstructed = reconstruct_coords(
+ coords_x0, pred, mask_rotate, rotatable_edges, batch_indices
+ )
+ loss_coord = F.mse_loss(reconstructed, coords_x1.to(device))
+
+ total = w_trans * loss_trans + w_rot * loss_rot + w_tor * loss_tor + w_coord * loss_coord
+
+ return {
+ 'total': total,
+ 'translation': loss_trans.detach(),
+ 'rotation': loss_rot.detach(),
+ 'torsion': loss_tor.detach(),
+ 'coord_recon': loss_coord.detach(),
+ }
+
+
+def reconstruct_coords(
+ coords_x0: torch.Tensor,
+ pred: dict,
+ mask_rotate: torch.Tensor,
+ rotatable_edges: torch.Tensor,
+ batch_indices: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Reconstruct coordinates from SE(3) + Torsion prediction.
+
+ Apply order: Torsion → Translation → Rotation.
+ """
+ device = coords_x0.device
+ coords = coords_x0.clone()
+ batch_size = pred['translation'].shape[0]
+
+ for b in range(batch_size):
+ mol_mask = (batch_indices == b)
+ mol_coords = coords[mol_mask]
+ mol_indices = torch.where(mol_mask)[0]
+ n_atoms = mol_indices.shape[0]
+ offset = mol_indices[0].item()
+
+ # 1. Torsion
+ if pred['torsion'].numel() > 0 and mask_rotate.shape[0] > 0:
+ mol_coords = _apply_torsions(
+ mol_coords, pred['torsion'], mask_rotate,
+ rotatable_edges, offset, n_atoms, device
+ )
+
+ # 2. Translation
+ mol_coords = mol_coords + pred['translation'][b]
+
+ # 3. Rotation around CoM
+ mol_coords = _apply_rotation(mol_coords, pred['rotation'][b])
+
+ coords[mol_mask] = mol_coords
+
+ return coords
+
+
+def _apply_torsions(mol_coords, torsion_values, mask_rotate, rotatable_edges, offset, n_atoms, device):
+ """Apply torsion angle changes to molecule coordinates."""
+ for m in range(mask_rotate.shape[0]):
+ angle = torsion_values[m]
+ if angle.abs() < 1e-6:
+ continue
+
+ mask = mask_rotate[m, offset:offset + n_atoms]
+ if not mask.any():
+ continue
+
+ src, dst = rotatable_edges[m]
+ src_l = src.item() - offset
+ dst_l = dst.item() - offset
+ if not (0 <= src_l < n_atoms and 0 <= dst_l < n_atoms):
+ continue
+
+ axis = mol_coords[dst_l] - mol_coords[src_l]
+ axis_norm = axis.norm()
+ if axis_norm < 1e-6:
+ continue
+ axis = axis / axis_norm
+
+ mol_coords = _rodrigues_rotate(mol_coords, mask, axis, mol_coords[dst_l], angle)
+
+ return mol_coords
+
+
+def _rodrigues_rotate(coords, mask, axis, pivot, angle):
+ """Rodrigues rotation of masked atoms around axis through pivot."""
+ relative = coords[mask] - pivot
+ cos_a = torch.cos(angle)
+ sin_a = torch.sin(angle)
+ dot = (relative * axis).sum(dim=-1, keepdim=True)
+ cross = torch.cross(axis.unsqueeze(0).expand_as(relative), relative, dim=-1)
+ rotated = relative * cos_a + cross * sin_a + axis * dot * (1 - cos_a)
+ coords = coords.clone()
+ coords[mask] = rotated + pivot
+ return coords
+
+
+def _apply_rotation(coords, rot_vec):
+ """Apply axis-angle rotation around center of mass."""
+ angle = rot_vec.norm()
+ if angle < 1e-6:
+ return coords
+ com = coords.mean(dim=0)
+ relative = coords - com
+ axis = rot_vec / angle
+ cos_a = torch.cos(angle)
+ sin_a = torch.sin(angle)
+ dot = (relative * axis).sum(dim=-1, keepdim=True)
+ cross = torch.cross(axis.unsqueeze(0).expand_as(relative), relative, dim=-1)
+ rotated = relative * cos_a + cross * sin_a + axis * dot * (1 - cos_a)
+ return rotated + com
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
deleted file mode 100644
index 5b17c95..0000000
--- a/src/utils/metrics.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""
-Metrics utilities for FlowFix training.
-
-Contains functions for:
-- RMSD calculation
-- Success rate calculation
-- DDP metric gathering
-"""
-
-from typing import List, Dict, Any
-import numpy as np
-import torch
-import torch.distributed as dist
-
-
-def compute_success_rates(rmsds: np.ndarray) -> Dict[str, float]:
- """
- Compute success rates at various RMSD thresholds.
-
- Args:
- rmsds: Array of RMSD values
-
- Returns:
- Dict with success rates: 'success_2A', 'success_1A', 'success_05A'
- """
- return {
- "success_2A": float(np.mean(rmsds < 2.0) * 100),
- "success_1A": float(np.mean(rmsds < 1.0) * 100),
- "success_05A": float(np.mean(rmsds < 0.5) * 100),
- }
-
-
-def gather_metrics_ddp(
- local_metrics: Dict[str, List[float]],
- world_size: int,
- device: torch.device,
-) -> Dict[str, np.ndarray]:
- """
- Gather metrics from all DDP processes.
-
- Handles variable-length lists across processes by padding before all_gather.
-
- Args:
- local_metrics: Dict of metric_name -> list of values
- world_size: Number of DDP processes
- device: Torch device
-
- Returns:
- Dict of metric_name -> numpy array of all gathered values
- """
- gathered = {}
-
- for metric_name, values in local_metrics.items():
- values_tensor = torch.tensor(values, device=device)
-
- # Gather sizes from all ranks
- local_size = torch.tensor([len(values)], device=device, dtype=torch.long)
- all_sizes = [
- torch.zeros(1, device=device, dtype=torch.long) for _ in range(world_size)
- ]
- dist.all_gather(all_sizes, local_size)
-
- # Find max size for padding
- max_size = max(s.item() for s in all_sizes)
-
- # Pad tensor to max size
- padded = torch.zeros(max_size, device=device)
- padded[: len(values)] = values_tensor
-
- # Gather padded tensors
- gathered_tensors = [
- torch.zeros(max_size, device=device) for _ in range(world_size)
- ]
- dist.all_gather(gathered_tensors, padded)
-
- # Unpad and concatenate
- unpadded = [
- gathered_tensors[i][: all_sizes[i].item()] for i in range(world_size)
- ]
- gathered[metric_name] = torch.cat(unpadded).cpu().numpy()
-
- return gathered
-
-
-def average_metrics_ddp(
- local_metrics: List[float],
- world_size: int,
- device: torch.device,
-) -> float:
- """
- Average a list of metrics across all DDP processes.
-
- Args:
- local_metrics: List of metric values from this process
- world_size: Number of DDP processes
- device: Torch device
-
- Returns:
- Averaged metric value
- """
- # Convert to tensor and compute local stats
- tensor = torch.tensor(local_metrics, device=device)
-
- # All-reduce sum
- dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
-
- # Average across GPUs and samples
- return tensor.mean().item() / world_size
diff --git a/src/utils/plot.py b/src/utils/plot.py
deleted file mode 100644
index f12b244..0000000
--- a/src/utils/plot.py
+++ /dev/null
@@ -1,762 +0,0 @@
-import matplotlib.pyplot as plt
-import numpy as np
-import os
-import seaborn as sns
-from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, average_precision_score
-import itertools
-from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
-from scipy.stats import pearsonr
-import pandas as pd
-
-def plot_scatter(y_true, y_pred, save_dir, type='valid', logger=None):
- """
- Draw scatter plot of true vs predicted values with regression metrics.
-
- Args:
- y_true: Ground truth values
- y_pred: Predicted values
- save_dir: Directory to save the plot
- type: Type of data ('train' or 'valid')
- logger: Logger object for logging messages
- """
- plt.figure(figsize=(10, 10))
-
- # Calculate metrics
- rmse = np.sqrt(mean_squared_error(y_true, y_pred))
- mae = mean_absolute_error(y_true, y_pred)
- r2 = r2_score(y_true, y_pred)
- pearson_corr, _ = pearsonr(y_true, y_pred)
-
- # Create scatter plot
- plt.scatter(y_true, y_pred, alpha=0.5, color='blue', label='Data points')
-
- # Plot the perfect prediction line
- min_val = min(min(y_true), min(y_pred))
- max_val = max(max(y_true), max(y_pred))
- plt.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect prediction')
-
- # Add metrics text box
- metrics_text = f'RMSE: {rmse:.4f}\nMAE: {mae:.4f}\nR²: {r2:.4f}\nPearson: {pearson_corr:.4f}'
- plt.text(0.05, 0.95, metrics_text,
- transform=plt.gca().transAxes,
- bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
- verticalalignment='top',
- fontsize=10)
-
- plt.xlabel('True Values')
- plt.ylabel('Predicted Values')
- plt.title(f'{type.capitalize()} Set: True vs Predicted Values')
- plt.legend()
- plt.grid(True, linestyle='--', alpha=0.7)
-
- # Make plot square with equal axes
- plt.axis('square')
-
- # Add colorbar to show density
- from scipy.stats import gaussian_kde
- xy = np.vstack([y_true, y_pred])
- z = gaussian_kde(xy)(xy)
- idx = z.argsort()
- x, y, z = y_true[idx], y_pred[idx], z[idx]
- scatter = plt.scatter(x, y, c=z, s=50, alpha=0.5, cmap='viridis')
- plt.colorbar(scatter, label='Density')
-
- save_path = os.path.join(save_dir, f'{type}_scatter_plot.png')
- plt.savefig(save_path, dpi=300, bbox_inches='tight')
- plt.close()
-
- if logger:
- logger.info(f"Scatter plot saved to {save_path}")
-
- return {
- 'RMSE': rmse,
- 'MAE': mae,
- 'R2': r2,
- 'Pearson': pearson_corr
- }
-
-def plot_losses(train_loss, valid_loss, save_dir, best_epoch=None, logger=None):
- """
- Plot training and validation losses over epochs
-
- Args:
- train_loss: List of training loss values
- valid_loss: List of validation loss values
- save_dir: Directory to save the plot
- best_epoch: Best epoch index (0-based)
- logger: Logger object
- """
- if not train_loss or not valid_loss:
- if logger:
- logger.warning("No loss data to plot")
- return
-
- # Get epochs for x-axis (assuming train_loss and valid_loss have the same length)
- start_epoch = 1 # default start epoch
- epochs = list(range(start_epoch, start_epoch + len(train_loss)))
-
- plt.figure(figsize=(10, 6))
- plt.plot(epochs, train_loss, 'b-', label='Training Loss')
- plt.plot(epochs, valid_loss, 'r-', label='Validation Loss')
-
- # Always show best_epoch if provided, but display all epochs regardless
- if best_epoch:
- plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch')
-
- plt.title('Training and Validation Loss')
- plt.xlabel('Epoch')
- plt.ylabel('Loss')
- plt.legend()
- plt.grid(True)
-
- plt.tight_layout()
- plt.savefig(os.path.join(save_dir, 'loss_history.png'), dpi=300)
- plt.close()
-
- if logger:
- logger.info(f"Loss history plot saved to {save_dir}/loss_history.png")
-
-def plot_metrics(train_metrics, valid_metrics, save_dir, best_epoch=None, logger=None):
- """Plot metrics history
-
- Args:
- train_metrics: List of training metrics dictionaries
- valid_metrics: List of validation metrics dictionaries
- save_dir: Directory to save plots
- best_epoch: Marker for best epoch (optional)
- logger: Logger for logging messages
- """
- if not train_metrics or not valid_metrics:
- if logger:
- logger.warning("No metrics to plot")
- return
-
- # Create figures directory
- figures_dir = os.path.join(save_dir, 'figures')
- os.makedirs(figures_dir, exist_ok=True)
-
- # Plot regression metrics
- plot_regression_metrics(train_metrics, valid_metrics, figures_dir, best_epoch=best_epoch, logger=logger)
-
- if logger:
- logger.info(f"Saved metrics plots to {figures_dir}")
-
-def plot_regression_metrics(train_metrics, valid_metrics, save_dir, best_epoch=None, logger=None):
- """Plot regression metrics
-
- Args:
- train_metrics: List of training metrics dictionaries
- valid_metrics: List of validation metrics dictionaries
- save_dir: Directory to save plots
- best_epoch: Marker for best epoch (optional)
- logger: Logger for logging messages
- """
- # Check if metrics are available
- if (
- len(valid_metrics) == 0 or
- 'metrics' not in valid_metrics[0] or
- 'regression' not in valid_metrics[0]['metrics']
- ):
- if logger:
- logger.warning("No regression metrics to plot")
- return
-
- # Extract regression metrics
- epochs = list(range(1, len(train_metrics) + 1))
-
- # Get all metrics (will be available in validation results)
- metric_names = list(valid_metrics[0]['metrics']['regression'].keys())
-
- # Create plot for each metric
- for metric_name in metric_names:
- try:
- valid_values = [valid_metrics[i]['metrics']['regression'][metric_name] for i in range(len(valid_metrics))]
-
- plt.figure(figsize=(10, 6))
-
- # Plot only validation values (training may not have all metrics)
- plt.plot(epochs, valid_values, 'b-', label=f'Validation {metric_name}')
-
- if best_epoch:
- # Mark best epoch
- plt.axvline(x=best_epoch, color='r', linestyle='--', label=f'Best Epoch ({best_epoch})')
-
- # Get best value
- best_value = valid_values[best_epoch - 1] if 0 <= best_epoch - 1 < len(valid_values) else None
- if best_value is not None:
- plt.scatter([best_epoch], [best_value], color='r', s=100, zorder=5)
- plt.text(best_epoch, best_value, f' {best_value:.4f}', verticalalignment='bottom')
-
- plt.xlabel('Epoch')
- plt.ylabel(metric_name)
- plt.title(f'Regression {metric_name} vs Epoch')
- plt.grid(True, linestyle='--', alpha=0.7)
- plt.legend()
-
- # Save figure
- plt.tight_layout()
- plt.savefig(os.path.join(save_dir, f'regression_{metric_name.lower()}.png'), dpi=300)
- plt.close()
-
- except Exception as e:
- if logger:
- logger.warning(f"Error plotting {metric_name}: {str(e)}")
- continue
-
- if logger:
- logger.info(f"Saved regression metrics plots to {save_dir}")
-
-def plot_confusion_matrix(cm, classes, save_dir, title='Confusion Matrix', task_type=None, cmap=plt.cm.Blues):
- """
- Plot confusion matrix
-
- Args:
- cm: Confusion matrix
- classes: List of class names
- save_dir: Directory to save the plot
- title: Plot title
- task_type: Task type for filename (optional)
- cmap: Colormap
- """
- plt.figure(figsize=(8, 6))
- plt.imshow(cm, interpolation='nearest', cmap=cmap)
- plt.title(title)
- plt.colorbar()
-
- tick_marks = np.arange(len(classes))
- plt.xticks(tick_marks, classes, rotation=45)
- plt.yticks(tick_marks, classes)
-
- fmt = 'd'
- thresh = cm.max() / 2.
- for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
- plt.text(j, i, format(cm[i, j], fmt),
- horizontalalignment="center",
- color="white" if cm[i, j] > thresh else "black")
-
- plt.tight_layout()
- plt.ylabel('True label')
- plt.xlabel('Predicted label')
-
- # 파일명 설정 (task_type이 있으면 사용, 없으면 기본 파일명)
- filename = f"{task_type}_cm.png" if task_type else "confusion_matrix.png"
- plt.savefig(os.path.join(save_dir, filename))
- plt.close()
-
-def plot_binding_metrics(y_true, y_pred, save_dir, task_type='binding', epoch=None, logger=None):
- """
- Plot metrics for binding or non-binding classification tasks in a single figure
- with 3 subplots: confusion matrix, ROC curve, and PR curve.
-
- Args:
- y_true: True labels
- y_pred: Predicted probabilities
- save_dir: Directory to save the plot
- task_type: 'binding' or 'non_binding'
- epoch: Current epoch number (not used in filename anymore)
- logger: Logger object
- """
- # Convert predictions to binary (0/1) using 0.5 threshold
- y_pred_binary = (y_pred >= 0.5).astype(int)
-
- # Compute confusion matrix
- cm = confusion_matrix(y_true, y_pred_binary)
-
- # Create a figure with 3 subplots
- fig, axes = plt.subplots(1, 3, figsize=(18, 6))
- plt.suptitle(f'{task_type.replace("_", " ").title()} Classification Metrics', fontsize=16)
-
- # 1. Plot confusion matrix
- axes[0].imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
- axes[0].set_title('Confusion Matrix')
-
- # Add text annotations to confusion matrix
- thresh = cm.max() / 2.
- for i in range(cm.shape[0]):
- for j in range(cm.shape[1]):
- axes[0].text(j, i, format(cm[i, j], 'd'),
- ha="center", va="center",
- color="white" if cm[i, j] > thresh else "black")
-
- axes[0].set_xticks([0, 1])
- axes[0].set_xticklabels(['Negative', 'Positive'])
- axes[0].set_yticks([0, 1])
- axes[0].set_yticklabels(['Negative', 'Positive'])
- axes[0].set_ylabel('True label')
- axes[0].set_xlabel('Predicted label')
-
- # 2. Plot ROC curve
- fpr, tpr, _ = roc_curve(y_true, y_pred)
- roc_auc = auc(fpr, tpr)
-
- axes[1].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
- axes[1].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
- axes[1].set_xlim([0.0, 1.0])
- axes[1].set_ylim([0.0, 1.05])
- axes[1].set_xlabel('False Positive Rate')
- axes[1].set_ylabel('True Positive Rate')
- axes[1].set_title('ROC Curve')
- axes[1].legend(loc='lower right')
- axes[1].grid(True, alpha=0.3)
-
- # 3. Plot precision-recall curve
- precision, recall, _ = precision_recall_curve(y_true, y_pred)
- avg_precision = average_precision_score(y_true, y_pred)
-
- axes[2].plot(recall, precision, color='blue', lw=2, label=f'PR curve (AP = {avg_precision:.3f})')
- axes[2].set_xlabel('Recall')
- axes[2].set_ylabel('Precision')
- axes[2].set_title('Precision-Recall Curve')
- axes[2].legend(loc='upper right')
- axes[2].grid(True, alpha=0.3)
- axes[2].set_ylim([0.0, 1.05])
-
- plt.tight_layout()
-
- # Save the combined figure without epoch number (already implemented)
- save_path = os.path.join(save_dir, f'{task_type}_metrics.png')
- plt.savefig(save_path, dpi=300)
- plt.close()
-
- if logger:
- logger.info(f"{task_type.replace('_', ' ').title()} metrics plot saved to {save_path}")
-
-def plot_loss_types(train_metrics, valid_metrics, save_dir, best_epoch=None, logger=None):
- """
- Plot different types of losses over training epochs in a 2x2 subplot format
- """
- # Always start from the first epoch
- if 'epoch' in train_metrics[0]:
- start_epoch = train_metrics[0]['epoch']
- else:
- start_epoch = 1
-
- epochs = list(range(start_epoch, start_epoch + len(train_metrics)))
-
- # Plot loss types (reg_loss, bind_loss, nonbind_loss, total_loss)
- plt.figure(figsize=(16, 12))
- plt.suptitle('Loss Types Over Training', fontsize=16)
-
- # Total Loss
- plt.subplot(2, 2, 1)
- plt.plot(epochs, [m['total_loss'] for m in train_metrics], 'b-', label='Training')
- plt.plot(epochs, [m['losses']['total_loss'] for m in valid_metrics], 'r-', label='Validation')
- if best_epoch:
- plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch')
- plt.xlabel('Epoch')
- plt.ylabel('Loss')
- plt.title('Total Loss')
- plt.legend()
- plt.grid(True)
-
- # Regression Loss
- plt.subplot(2, 2, 2)
- plt.plot(epochs, [m['reg_loss'] for m in train_metrics], 'b-', label='Training')
- plt.plot(epochs, [m['losses']['reg_loss'] for m in valid_metrics], 'r-', label='Validation')
- if best_epoch:
- plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch')
- plt.xlabel('Epoch')
- plt.ylabel('Loss')
- plt.title('Regression Loss')
- plt.legend()
- plt.grid(True)
-
- # Binding Loss
- plt.subplot(2, 2, 3)
- plt.plot(epochs, [m['bind_loss'] for m in train_metrics], 'b-', label='Training')
- plt.plot(epochs, [m['losses']['bind_loss'] for m in valid_metrics], 'r-', label='Validation')
- if best_epoch:
- plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch')
- plt.xlabel('Epoch')
- plt.ylabel('Loss')
- plt.title('Binding Loss')
- plt.legend()
- plt.grid(True)
-
- # Non-binding Loss
- plt.subplot(2, 2, 4)
- plt.plot(epochs, [m['nonbind_loss'] for m in train_metrics], 'b-', label='Training')
- plt.plot(epochs, [m['losses']['nonbind_loss'] for m in valid_metrics], 'r-', label='Validation')
- if best_epoch:
- plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best epoch')
- plt.xlabel('Epoch')
- plt.ylabel('Loss')
- plt.title('Non-binding Loss')
- plt.legend()
- plt.grid(True)
-
- plt.tight_layout()
- plt.savefig(os.path.join(save_dir, 'loss_types.png'), dpi=300)
- plt.close()
-
- if logger:
- logger.info(f"Loss types plot saved to {save_dir}/loss_types.png")
-
-# Set style
-plt.style.use('seaborn-v0_8')
-sns.set_palette("husl")
-
-def plot_loss_curves(train_losses, valid_losses, save_dir, title="Training History"):
- """
- Plot training and validation loss curves
-
- Args:
- train_losses: List of training losses
- valid_losses: List of validation losses
- save_dir: Directory to save the plot
- title: Plot title
- """
- plt.figure(figsize=(12, 5))
-
- # Loss subplot
- plt.subplot(1, 2, 1)
- epochs = range(1, len(train_losses) + 1)
-
- plt.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
- plt.plot(epochs, valid_losses, 'r-', label='Validation Loss', linewidth=2)
-
- plt.title('Loss Curves')
- plt.xlabel('Epoch')
- plt.ylabel('Loss')
- plt.legend()
- plt.grid(True, alpha=0.3)
-
- # Find best epoch (minimum validation loss)
- best_epoch = np.argmin(valid_losses) + 1
- plt.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
- plt.legend()
-
- # Loss difference subplot
- plt.subplot(1, 2, 2)
- loss_diff = np.array(valid_losses) - np.array(train_losses)
- plt.plot(epochs, loss_diff, 'purple', linewidth=2)
- plt.title('Validation - Training Loss')
- plt.xlabel('Epoch')
- plt.ylabel('Loss Difference')
- plt.grid(True, alpha=0.3)
- plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
-
- plt.tight_layout()
- plt.savefig(os.path.join(save_dir, 'loss_curves.png'), dpi=300, bbox_inches='tight')
- plt.close()
-
- print(f"Loss curves saved to {save_dir}/loss_curves.png")
-
-def plot_r2_curves(train_r2_scores, valid_r2_scores, save_dir, title="R² Score History"):
- """
- Plot training and validation R² score curves
-
- Args:
- train_r2_scores: List of training R² scores
- valid_r2_scores: List of validation R² scores
- save_dir: Directory to save the plot
- title: Plot title
- """
- plt.figure(figsize=(12, 5))
-
- # R² subplot
- plt.subplot(1, 2, 1)
- epochs = range(1, len(train_r2_scores) + 1)
-
- plt.plot(epochs, train_r2_scores, 'b-', label='Training R²', linewidth=2)
- plt.plot(epochs, valid_r2_scores, 'r-', label='Validation R²', linewidth=2)
-
- plt.title('R² Score Curves')
- plt.xlabel('Epoch')
- plt.ylabel('R² Score')
- plt.legend()
- plt.grid(True, alpha=0.3)
-
- # Find best epoch (maximum validation R²)
- best_epoch = np.argmax(valid_r2_scores) + 1
- best_r2 = valid_r2_scores[best_epoch - 1]
- plt.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
- plt.scatter([best_epoch], [best_r2], color='g', s=100, zorder=5)
- plt.text(best_epoch, best_r2, f' {best_r2:.4f}', verticalalignment='bottom', fontweight='bold')
- plt.legend()
-
- # R² difference subplot
- plt.subplot(1, 2, 2)
- r2_diff = np.array(train_r2_scores) - np.array(valid_r2_scores)
- plt.plot(epochs, r2_diff, 'purple', linewidth=2)
- plt.title('Training - Validation R²')
- plt.xlabel('Epoch')
- plt.ylabel('R² Difference')
- plt.grid(True, alpha=0.3)
- plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
-
- # Add interpretation guidelines
- plt.text(0.02, 0.98, 'Positive: Overfitting\nNegative: Underfitting',
- transform=plt.gca().transAxes, verticalalignment='top',
- bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
-
- plt.tight_layout()
- plt.savefig(os.path.join(save_dir, 'r2_curves.png'), dpi=300, bbox_inches='tight')
- plt.close()
-
- print(f"R² curves saved to {save_dir}/r2_curves.png")
-
- # Print summary statistics
- print(f"Best validation R²: {best_r2:.4f} at epoch {best_epoch}")
- print(f"Final validation R²: {valid_r2_scores[-1]:.4f}")
- print(f"Max training R²: {max(train_r2_scores):.4f}")
- print(f"Max validation R²: {max(valid_r2_scores):.4f}")
-
-def plot_predictions(true_values, predictions, save_dir, title="Prediction Results"):
- """Plot prediction vs true values scatter plot"""
- # Create figure with subplots
- fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
-
- # Convert to numpy arrays
- true_values = np.array(true_values)
- predictions = np.array(predictions)
-
- # Calculate metrics
- mse = mean_squared_error(true_values, predictions)
- mae = mean_absolute_error(true_values, predictions)
- rmse = np.sqrt(mse)
- r2 = r2_score(true_values, predictions)
- pearson_r, pearson_p = pearsonr(true_values, predictions)
-
- # 1. Scatter plot
- ax1.scatter(true_values, predictions, alpha=0.6, s=30)
-
- # Perfect prediction line
- min_val = min(true_values.min(), predictions.min())
- max_val = max(true_values.max(), predictions.max())
- ax1.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect Prediction')
-
- ax1.set_xlabel('True Affinity', fontsize=12)
- ax1.set_ylabel('Predicted Affinity', fontsize=12)
- ax1.set_title(f'Predictions vs True Values\nR² = {r2:.4f}, RMSE = {rmse:.4f}', fontsize=14)
- ax1.legend()
- ax1.grid(True, alpha=0.3)
-
- # 2. Residual plot
- residuals = predictions - true_values
- ax2.scatter(predictions, residuals, alpha=0.6, s=30)
- ax2.axhline(y=0, color='r', linestyle='--', lw=2)
- ax2.set_xlabel('Predicted Affinity', fontsize=12)
- ax2.set_ylabel('Residuals', fontsize=12)
- ax2.set_title(f'Residual Plot\nMAE = {mae:.4f}', fontsize=14)
- ax2.grid(True, alpha=0.3)
-
- # 3. Distribution of residuals
- ax3.hist(residuals, bins=50, alpha=0.7, edgecolor='black')
- ax3.axvline(residuals.mean(), color='r', linestyle='--', lw=2,
- label=f'Mean: {residuals.mean():.4f}')
- ax3.axvline(0, color='g', linestyle='-', lw=2, label='Zero')
- ax3.set_xlabel('Residuals', fontsize=12)
- ax3.set_ylabel('Frequency', fontsize=12)
- ax3.set_title('Distribution of Residuals', fontsize=14)
- ax3.legend()
- ax3.grid(True, alpha=0.3)
-
- # 4. Error distribution
- abs_errors = np.abs(residuals)
- ax4.hist(abs_errors, bins=50, alpha=0.7, edgecolor='black', color='orange')
- ax4.axvline(abs_errors.mean(), color='r', linestyle='--', lw=2,
- label=f'Mean AE: {abs_errors.mean():.4f}')
- ax4.axvline(np.median(abs_errors), color='g', linestyle='--', lw=2,
- label=f'Median AE: {np.median(abs_errors):.4f}')
- ax4.set_xlabel('Absolute Error', fontsize=12)
- ax4.set_ylabel('Frequency', fontsize=12)
- ax4.set_title('Distribution of Absolute Errors', fontsize=14)
- ax4.legend()
- ax4.grid(True, alpha=0.3)
-
- # Add overall metrics text
- metrics_text = f"""
- Metrics Summary:
- ├─ RMSE: {rmse:.4f}
- ├─ MAE: {mae:.4f}
- ├─ R²: {r2:.4f}
- ├─ Pearson R: {pearson_r:.4f}
- └─ Pearson P: {pearson_p:.4e}
- """
-
- fig.text(0.02, 0.02, metrics_text, fontsize=10, fontfamily='monospace',
- bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8))
-
- plt.suptitle(title, fontsize=16, fontweight='bold')
- plt.tight_layout()
- plt.subplots_adjust(top=0.93, bottom=0.15)
- plt.savefig(os.path.join(save_dir, 'prediction_results.png'), dpi=300, bbox_inches='tight')
- plt.close()
-
-def plot_error_analysis(true_values, predictions, save_dir):
- """Plot detailed error analysis"""
- true_values = np.array(true_values)
- predictions = np.array(predictions)
- residuals = predictions - true_values
- abs_errors = np.abs(residuals)
-
- fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
-
- # 1. Error vs True Values
- ax1.scatter(true_values, abs_errors, alpha=0.6, s=30)
- ax1.set_xlabel('True Affinity', fontsize=12)
- ax1.set_ylabel('Absolute Error', fontsize=12)
- ax1.set_title('Absolute Error vs True Values', fontsize=14)
- ax1.grid(True, alpha=0.3)
-
- # Add trend line
- z = np.polyfit(true_values, abs_errors, 1)
- p = np.poly1d(z)
- ax1.plot(sorted(true_values), p(sorted(true_values)), "r--", alpha=0.8)
-
- # 2. Error vs Predicted Values
- ax2.scatter(predictions, abs_errors, alpha=0.6, s=30, color='orange')
- ax2.set_xlabel('Predicted Affinity', fontsize=12)
- ax2.set_ylabel('Absolute Error', fontsize=12)
- ax2.set_title('Absolute Error vs Predictions', fontsize=14)
- ax2.grid(True, alpha=0.3)
-
- # 3. Q-Q Plot for residuals
- from scipy import stats
- stats.probplot(residuals, dist="norm", plot=ax3)
- ax3.set_title('Q-Q Plot of Residuals', fontsize=14)
- ax3.grid(True, alpha=0.3)
-
- # 4. Cumulative error distribution
- sorted_errors = np.sort(abs_errors)
- cumulative = np.arange(1, len(sorted_errors) + 1) / len(sorted_errors)
- ax4.plot(sorted_errors, cumulative, linewidth=2)
- ax4.set_xlabel('Absolute Error', fontsize=12)
- ax4.set_ylabel('Cumulative Probability', fontsize=12)
- ax4.set_title('Cumulative Error Distribution', fontsize=14)
- ax4.grid(True, alpha=0.3)
-
- # Add percentile lines
- percentiles = [0.5, 0.8, 0.9, 0.95]
- for p in percentiles:
- error_at_p = np.percentile(abs_errors, p * 100)
- ax4.axvline(error_at_p, color='red', linestyle='--', alpha=0.7)
- ax4.text(error_at_p, p, f'{p*100:.0f}%', rotation=90,
- verticalalignment='bottom', fontsize=10)
-
- plt.suptitle('Error Analysis', fontsize=16, fontweight='bold')
- plt.tight_layout()
- plt.savefig(os.path.join(save_dir, 'error_analysis.png'), dpi=300, bbox_inches='tight')
- plt.close()
-
-def plot_affinity_distribution(true_values, predictions, save_dir):
- """Plot distribution comparison of true vs predicted affinities"""
- fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
-
- # 1. True values distribution
- ax1.hist(true_values, bins=50, alpha=0.7, label='True Affinity', color='blue', edgecolor='black')
- ax1.set_xlabel('Affinity', fontsize=12)
- ax1.set_ylabel('Frequency', fontsize=12)
- ax1.set_title('True Affinity Distribution', fontsize=14)
- ax1.legend()
- ax1.grid(True, alpha=0.3)
-
- # 2. Predicted values distribution
- ax2.hist(predictions, bins=50, alpha=0.7, label='Predicted Affinity', color='red', edgecolor='black')
- ax2.set_xlabel('Affinity', fontsize=12)
- ax2.set_ylabel('Frequency', fontsize=12)
- ax2.set_title('Predicted Affinity Distribution', fontsize=14)
- ax2.legend()
- ax2.grid(True, alpha=0.3)
-
- # 3. Overlapped distributions
- ax3.hist(true_values, bins=50, alpha=0.5, label='True Affinity', color='blue', density=True)
- ax3.hist(predictions, bins=50, alpha=0.5, label='Predicted Affinity', color='red', density=True)
- ax3.set_xlabel('Affinity', fontsize=12)
- ax3.set_ylabel('Density', fontsize=12)
- ax3.set_title('Distribution Comparison', fontsize=14)
- ax3.legend()
- ax3.grid(True, alpha=0.3)
-
- plt.tight_layout()
- plt.savefig(os.path.join(save_dir, 'affinity_distributions.png'), dpi=300, bbox_inches='tight')
- plt.close()
-
-def create_summary_report(true_values, predictions, save_dir, model_info=None):
- """Create a comprehensive summary report"""
- true_values = np.array(true_values)
- predictions = np.array(predictions)
-
- # Calculate all metrics
- mse = mean_squared_error(true_values, predictions)
- mae = mean_absolute_error(true_values, predictions)
- rmse = np.sqrt(mse)
- r2 = r2_score(true_values, predictions)
- pearson_r, pearson_p = pearsonr(true_values, predictions)
-
- residuals = predictions - true_values
- abs_errors = np.abs(residuals)
-
- # Create summary text
- report = f"""
-# Model Performance Report
-
-## Model Information
-{f"- Model: {model_info.get('model_name', 'EGNN Affinity Model')}" if model_info else "- Model: EGNN Affinity Model"}
-{f"- Parameters: {model_info.get('total_parameters', 'N/A')}" if model_info else ""}
-{f"- Training Time: {model_info.get('training_time', 'N/A')}" if model_info else ""}
-
-## Dataset Statistics
-- Total Samples: {len(true_values):,}
-- True Affinity Range: [{true_values.min():.4f}, {true_values.max():.4f}]
-- True Affinity Mean ± Std: {true_values.mean():.4f} ± {true_values.std():.4f}
-- Predicted Affinity Range: [{predictions.min():.4f}, {predictions.max():.4f}]
-- Predicted Affinity Mean ± Std: {predictions.mean():.4f} ± {predictions.std():.4f}
-
-## Performance Metrics
-### Primary Metrics
-- Root Mean Square Error (RMSE): {rmse:.4f}
-- Mean Absolute Error (MAE): {mae:.4f}
-- R-squared Score (R²): {r2:.4f}
-- Pearson Correlation: {pearson_r:.4f} (p-value: {pearson_p:.2e})
-
-### Error Statistics
-- Mean Residual: {residuals.mean():.4f}
-- Std Residual: {residuals.std():.4f}
-- Mean Absolute Error: {abs_errors.mean():.4f}
-- Median Absolute Error: {np.median(abs_errors):.4f}
-- 90th Percentile Error: {np.percentile(abs_errors, 90):.4f}
-- 95th Percentile Error: {np.percentile(abs_errors, 95):.4f}
-- Max Absolute Error: {abs_errors.max():.4f}
-
-### Error Distribution
-- Errors < 0.5: {(abs_errors < 0.5).sum():,} ({(abs_errors < 0.5).mean()*100:.1f}%)
-- Errors < 1.0: {(abs_errors < 1.0).sum():,} ({(abs_errors < 1.0).mean()*100:.1f}%)
-- Errors < 2.0: {(abs_errors < 2.0).sum():,} ({(abs_errors < 2.0).mean()*100:.1f}%)
-- Errors ≥ 2.0: {(abs_errors >= 2.0).sum():,} ({(abs_errors >= 2.0).mean()*100:.1f}%)
-
-## Summary
-The model achieves {'excellent' if r2 > 0.8 else 'good' if r2 > 0.6 else 'moderate' if r2 > 0.4 else 'poor'}
-performance with an R² score of {r2:.4f} and RMSE of {rmse:.4f}.
-"""
-
- # Save report
- with open(os.path.join(save_dir, 'performance_report.md'), 'w') as f:
- f.write(report)
-
- # Also create a CSV with detailed results
- results_df = pd.DataFrame({
- 'True_Affinity': true_values,
- 'Predicted_Affinity': predictions,
- 'Residual': residuals,
- 'Absolute_Error': abs_errors
- })
-
- results_df.to_csv(os.path.join(save_dir, 'detailed_results.csv'), index=False)
-
- print("Summary report and detailed results saved.")
-
-def plot_all_results(true_values, predictions, save_dir, model_info=None):
- """Generate all plots and reports"""
- print("Generating comprehensive result analysis...")
-
- # Generate all plots
- plot_predictions(true_values, predictions, save_dir)
- plot_error_analysis(true_values, predictions, save_dir)
- plot_affinity_distribution(true_values, predictions, save_dir)
-
- # Create summary report
- create_summary_report(true_values, predictions, save_dir, model_info)
-
- print(f"All results saved to: {save_dir}")
\ No newline at end of file
diff --git a/src/utils/sampling_torsion.py b/src/utils/sampling_torsion.py
new file mode 100644
index 0000000..2658f61
--- /dev/null
+++ b/src/utils/sampling_torsion.py
@@ -0,0 +1,111 @@
+"""
+ODE sampling for SE(3) + Torsion decomposition.
+
+Applies Torsion → Translation → Rotation at each Euler step.
+"""
+
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from .losses_torsion import _rodrigues_rotate, _apply_rotation
+
+
+@torch.no_grad()
+def sample_trajectory_torsion(
+ model: torch.nn.Module,
+ protein_batch,
+ ligand_batch,
+ x0: torch.Tensor,
+ timesteps: torch.Tensor,
+ rotatable_edges: torch.Tensor,
+ mask_rotate: torch.Tensor,
+ return_trajectory: bool = False,
+) -> dict:
+ """
+ Sample trajectory using SE(3) + Torsion decomposition.
+
+ Args:
+ model: Flow matching model (ProteinLigandFlowMatchingTorsion)
+ protein_batch: Protein PyG batch
+ ligand_batch: Ligand PyG batch
+ x0: Initial docked coordinates [N_atoms, 3]
+ timesteps: Integration timesteps [num_steps + 1]
+ rotatable_edges: [M, 2] rotatable bond atom indices
+ mask_rotate: [M, N] boolean mask for torsion
+ return_trajectory: Whether to return full trajectory
+
+ Returns:
+ Dict with 'final_coords', optionally 'trajectory'
+ """
+ device = x0.device
+ batch_size = ligand_batch.batch.max().item() + 1
+ num_steps = len(timesteps) - 1
+
+ current_coords = x0.clone()
+ trajectory = [current_coords.clone()] if return_trajectory else []
+
+ for step in range(num_steps):
+ t_current = timesteps[step]
+ t_next = timesteps[step + 1]
+ dt = (t_next - t_current).item()
+
+ t = torch.ones(batch_size, device=device) * t_current
+
+ ligand_batch_t = ligand_batch.clone()
+ ligand_batch_t.pos = current_coords.clone()
+
+ # Unwrap DDP if needed
+ model_fn = model.module if isinstance(model, DDP) else model
+ output = model_fn(protein_batch, ligand_batch_t, t, rotatable_edges=rotatable_edges)
+
+ # Apply per molecule: Torsion → Translation → Rotation
+ for b in range(batch_size):
+ mol_mask = (ligand_batch.batch == b)
+ mol_coords = current_coords[mol_mask].clone()
+ mol_indices = torch.where(mol_mask)[0]
+ n_atoms = mol_indices.shape[0]
+ offset = mol_indices[0].item()
+
+ # 1. Torsion
+ if output['torsion'].numel() > 0 and mask_rotate.shape[0] > 0:
+ for m in range(mask_rotate.shape[0]):
+ angle = dt * output['torsion'][m].item()
+ if abs(angle) < 1e-6:
+ continue
+
+ mask_local = mask_rotate[m, offset:offset + n_atoms]
+ if not mask_local.any():
+ continue
+
+ src, dst = rotatable_edges[m]
+ src_l = src.item() - offset
+ dst_l = dst.item() - offset
+ if not (0 <= src_l < n_atoms and 0 <= dst_l < n_atoms):
+ continue
+
+ axis = mol_coords[dst_l] - mol_coords[src_l]
+ axis_norm = axis.norm()
+ if axis_norm < 1e-6:
+ continue
+ axis = axis / axis_norm
+
+ mol_coords = _rodrigues_rotate(
+ mol_coords, mask_local, axis, mol_coords[dst_l],
+ torch.tensor(angle, device=device)
+ )
+
+ # 2. Translation
+ mol_coords = mol_coords + dt * output['translation'][b]
+
+ # 3. Rotation
+ mol_coords = _apply_rotation(mol_coords, dt * output['rotation'][b])
+
+ current_coords[mol_mask] = mol_coords
+
+ if return_trajectory:
+ trajectory.append(current_coords.clone())
+
+ result = {"final_coords": current_coords}
+ if return_trajectory:
+ result["trajectory"] = trajectory
+ return result
diff --git a/src/utils/train.py b/src/utils/train.py
deleted file mode 100644
index c106f05..0000000
--- a/src/utils/train.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import dgl, torch, scipy
-import numpy as np
-
-import torch.nn as nn
-import torch.nn.functional as F
-
-from tqdm import tqdm
-from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, matthews_corrcoef
-from scipy.stats import spearmanr
-
-
-def custom_loss(pred, true, direction='over', penalty_factor=2.0):
- base_loss = (pred - true) ** 2
- if direction == 'over':
- penalty = torch.where(pred > true, penalty_factor * base_loss, base_loss)
- elif direction == 'under':
- penalty = torch.where(pred < true, penalty_factor * base_loss, base_loss)
- else:
- penalty = base_loss
-
- return torch.mean(penalty)
-
-
-
-def run_train_epoch(model, loader, optimizer, scheduler, device='cpu'):
- """Run a single training epoch.
-
- Args:
- model: The model to train
- loader: DataLoader providing batches
- optimizer: Optimizer for updating weights
- scheduler: Learning rate scheduler
- device: Device to use for computation
-
- Returns:
- Dictionary containing average losses and metrics
- """
- model.train()
-
- total_loss = 0
- num_batches = 0
-
- for batch in tqdm(loader, desc="Training", total=len(loader)):
- prot_data, ligand_graph, ligand_mask, interaction_mask, reg_true, bind_true, nonbind_true = batch
-
- # Move data to device
- for key in prot_data:
- prot_data[key] = prot_data[key].to(device)
-
- ligand_graph = ligand_graph.to(device)
- ligand_mask = ligand_mask.to(device)
- interaction_mask = interaction_mask.to(device)
-
- reg_true = reg_true.to(device)
-
- # Forward pass - regression only
- reg_pred = model(prot_data, ligand_graph, interaction_mask)
-
- # Calculate loss using custom loss function
-#loss = custom_loss(reg_pred.squeeze(), reg_true, direction='over', penalty_factor=2.0)
- # Huber
- loss = F.huber_loss(reg_pred.squeeze(), reg_true)
-
- optimizer.zero_grad(set_to_none=True)
- loss.backward()
-
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
-
- optimizer.step()
-
- total_loss += loss.item()
- num_batches += 1
-
- torch.cuda.empty_cache()
-
- scheduler.step()
-
- return {
- 'total_loss': total_loss / max(1, num_batches)
- }
-
-@torch.no_grad()
-def run_eval_epoch(model, loader, device='cpu'):
- """Run a single evaluation epoch.
-
- Args:
- model: The model to evaluate
- loader: DataLoader providing batches
- device: Device to use for computation
-
- Returns:
- Dictionary containing losses and metrics for regression task
- """
- model.eval()
-
- all_reg_true = []
- all_reg_pred = []
-
- total_loss = 0
- num_batches = 0
-
- for batch in tqdm(loader, desc="Evaluation", total=len(loader)):
- prot_data, ligand_graph, ligand_mask, interaction_mask, reg_true, bind_true, nonbind_true = batch
-
- # Move data to device
- for key in prot_data:
- prot_data[key] = prot_data[key].to(device)
-
- ligand_graph = ligand_graph.to(device)
- ligand_mask = ligand_mask.to(device)
- interaction_mask = interaction_mask.to(device)
-
- reg_true = reg_true.to(device)
-
- # Forward pass - regression only
- reg_pred = model(prot_data, ligand_graph, interaction_mask)
-
- # Calculate loss using custom loss function
-#loss = custom_loss(reg_pred.squeeze(), reg_true, direction='over', penalty_factor=2.0)
- loss = F.huber_loss(reg_pred.squeeze(), reg_true)
-
- # Accumulate losses
- total_loss += loss.item()
- num_batches += 1
-
- # Collect predictions and true values
- all_reg_true.append(reg_true)
- all_reg_pred.append(reg_pred)
-
- torch.cuda.empty_cache()
-
- # Concatenate all predictions and true values
- reg_true = torch.cat(all_reg_true, dim=0)
- reg_pred = torch.cat(all_reg_pred, dim=0)
-
- # Calculate average loss
- avg_loss = total_loss / max(1, num_batches)
-
- # Compute metrics for regression task
- reg_metrics = compute_regression_metrics(reg_true, reg_pred)
-
- return {
- 'losses': {
- 'total_loss': avg_loss
- },
- 'metrics': {
- 'regression': reg_metrics
- }
- }
-
-
-@torch.no_grad()
-def compute_regression_metrics(true, pred):
- """Compute comprehensive regression metrics"""
- true = true.cpu().numpy()
- pred = pred.cpu().numpy().squeeze()
-
- # Basic error metrics
- mse = np.mean((true - pred) ** 2)
- rmse = np.sqrt(mse)
- mae = np.mean(np.abs(true - pred))
-
- # Correlation metrics
- pearson = np.corrcoef(true, pred)[0, 1] if len(true) > 1 else 0
- spearman = spearmanr(true, pred)[0] if len(true) > 1 else 0
-
- # Explained variance metrics
- ss_tot = np.sum((true - true.mean()) ** 2)
- ss_res = np.sum((true - pred) ** 2)
- r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0
-
- residuals = true - pred
- mean_bias = np.mean(residuals) # Bias (systematic error)
-
- return {
- 'MSE': mse,
- 'RMSE': rmse,
- 'MAE': mae,
- 'Pearson': pearson,
- 'Spearman': spearman,
- 'R2': r2,
- 'Mean_Bias': mean_bias
- }
-@torch.no_grad()
-def compute_binary_metrics(true, pred):
- """Compute essential binary classification metrics"""
- true = true.cpu().numpy()
- pred_probs = torch.sigmoid(pred).cpu().numpy().squeeze()
- pred_class = (pred_probs >= 0.5).astype(int)
-
- tn, fp, fn, tp = confusion_matrix(true, pred_class).ravel()
-
- return {
- 'accuracy': accuracy_score(true, pred_class),
- 'precision': precision_score(true, pred_class, zero_division=0),
- 'recall': recall_score(true, pred_class, zero_division=0),
- 'f1': f1_score(true, pred_class, zero_division=0),
- 'auc_roc': roc_auc_score(true, pred_probs) if len(np.unique(true)) > 1 else 0,
- 'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
- 'balanced_accuracy': (tp / (tp + fn) + tn / (tn + fp)) / 2 if (tp + fn) > 0 and (tn + fp) > 0 else 0,
- 'mcc': matthews_corrcoef(true, pred_class),
- 'f1_macro': f1_score(true, pred_class, average='macro', zero_division=0),
- 'npv': tn / (tn + fn) if (tn + fn) > 0 else 0
- }
-
diff --git a/tests/test_joint_model.py b/tests/test_joint_model.py
deleted file mode 100644
index 126afd2..0000000
--- a/tests/test_joint_model.py
+++ /dev/null
@@ -1,335 +0,0 @@
-"""
-Test script for the Joint Graph Architecture.
-Verifies forward pass shapes and basic functionality.
-"""
-
-import sys
-import os
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
-
-import torch
-from torch_geometric.data import Data, Batch
-
-
-def create_dummy_protein_batch(batch_size=2, device='cpu'):
- """Create dummy protein batch for testing."""
- graphs = []
- for i in range(batch_size):
- n_residues = 30 + i * 10 # Variable sizes
- n_edges = n_residues * 8 # ~8 edges per node
-
- graph = Data(
- x=torch.randn(n_residues, 76, device=device),
- pos=torch.randn(n_residues, 3, device=device) * 10,
- edge_index=torch.randint(0, n_residues, (2, n_edges), device=device),
- edge_attr=torch.randn(n_edges, 39, device=device),
- node_vector_features=torch.randn(n_residues, 31, 3, device=device),
- edge_vector_features=torch.randn(n_edges, 8, 3, device=device),
- )
- graphs.append(graph)
-
- return Batch.from_data_list(graphs)
-
-
-def create_dummy_ligand_batch(batch_size=2, device='cpu'):
- """Create dummy ligand batch for testing."""
- graphs = []
- for i in range(batch_size):
- n_atoms = 20 + i * 5 # Variable sizes
- n_edges = n_atoms * 4 # ~4 edges per atom
-
- graph = Data(
- x=torch.randn(n_atoms, 122, device=device),
- pos=torch.randn(n_atoms, 3, device=device) * 5,
- edge_index=torch.randint(0, n_atoms, (2, n_edges), device=device),
- edge_attr=torch.randn(n_edges, 44, device=device),
- )
- graphs.append(graph)
-
- return Batch.from_data_list(graphs)
-
-
-def test_build_cross_edges():
- """Test cross-edge construction."""
- from src.models.network import build_cross_edges
-
- protein_pos = torch.randn(30, 3) * 10
- ligand_pos = torch.randn(20, 3) * 5
-
- edge_index = build_cross_edges(
- protein_pos, ligand_pos,
- distance_cutoff=10.0, max_neighbors=8
- )
-
- print(f"Cross-edge construction:")
- print(f" Protein nodes: 30, Ligand nodes: 20")
- print(f" Cross edges: {edge_index.shape[1]} (bidirectional)")
- print(f" Edge index shape: {edge_index.shape}")
-
- # Verify bidirectional: first half P→L, second half L→P
- n_half = edge_index.shape[1] // 2
- print(f" P→L edges: {n_half}, L→P edges: {edge_index.shape[1] - n_half}")
-
- assert edge_index.shape[0] == 2
- assert edge_index.shape[1] > 0
- print(" PASSED")
-
-
-def test_build_intra_edges():
- """Test intra-edge construction."""
- from src.models.network import build_intra_edges
-
- n_nodes = 30
- pos = torch.randn(n_nodes, 3) * 10
-
- # Create some existing edges (simulating pre-computed backbone edges)
- existing_src = torch.arange(0, n_nodes - 1)
- existing_dst = torch.arange(1, n_nodes)
- existing_edge_index = torch.stack([
- torch.cat([existing_src, existing_dst]),
- torch.cat([existing_dst, existing_src])
- ]) # Bidirectional sequential edges
-
- edge_index = build_intra_edges(
- pos, existing_edge_index,
- distance_cutoff=15.0, max_neighbors=8
- )
-
- print(f"\nIntra-edge construction:")
- print(f" Nodes: {n_nodes}")
- print(f" Existing edges: {existing_edge_index.shape[1]}")
- print(f" New intra edges: {edge_index.shape[1]}")
- print(f" Edge index shape: {edge_index.shape}")
-
- assert edge_index.shape[0] == 2
- # Verify no overlap with existing edges
- if edge_index.shape[1] > 0:
- existing_keys = set((existing_edge_index[0].long() * n_nodes + existing_edge_index[1].long()).tolist())
- new_keys = (edge_index[0].long() * n_nodes + edge_index[1].long()).tolist()
- overlap = sum(1 for k in new_keys if k in existing_keys)
- print(f" Overlap with existing: {overlap} (should be 0)")
- assert overlap == 0
- print(" PASSED")
-
-
-def test_joint_network():
- """Test JointProteinLigandNetwork forward pass."""
- from src.models.network import JointProteinLigandNetwork
-
- device = 'cpu'
- batch_size = 2
-
- network = JointProteinLigandNetwork(
- protein_input_scalar_dim=76,
- protein_input_vector_dim=31,
- protein_input_edge_scalar_dim=39,
- protein_input_edge_vector_dim=8,
- ligand_input_scalar_dim=122,
- ligand_input_edge_scalar_dim=44,
- hidden_scalar_dim=32, # Small for testing
- hidden_vector_dim=8,
- hidden_edge_dim=32,
- cross_edge_distance_cutoff=15.0, # Larger cutoff for random data
- cross_edge_max_neighbors=8,
- cross_edge_num_rbf=16,
- intra_edge_distance_cutoff=15.0,
- intra_edge_max_neighbors=8,
- num_layers=2, # Few layers for speed
- dropout=0.0,
- condition_dim=64,
- ).to(device)
-
- protein_batch = create_dummy_protein_batch(batch_size, device)
- ligand_batch = create_dummy_ligand_batch(batch_size, device)
-
- # Create dummy time condition [B, condition_dim]
- time_condition = torch.randn(batch_size, 64, device=device)
-
- print(f"\nJointProteinLigandNetwork forward pass:")
- print(f" Protein nodes: {protein_batch.num_nodes}")
- print(f" Ligand nodes: {ligand_batch.num_nodes}")
-
- velocity = network(protein_batch, ligand_batch, time_condition=time_condition)
-
- print(f" velocity shape: {velocity.shape} (expected [{ligand_batch.num_nodes}, 3])")
-
- assert velocity.shape == (ligand_batch.num_nodes, 3)
- print(" PASSED")
-
-
-def test_full_model():
- """Test ProteinLigandFlowMatchingJoint end-to-end."""
- from src.models.flowmatching import ProteinLigandFlowMatchingJoint
-
- device = 'cpu'
- batch_size = 2
-
- model = ProteinLigandFlowMatchingJoint(
- protein_input_scalar_dim=76,
- protein_input_vector_dim=31,
- protein_input_edge_scalar_dim=39,
- protein_input_edge_vector_dim=8,
- ligand_input_scalar_dim=122,
- ligand_input_edge_scalar_dim=44,
- hidden_scalar_dim=32,
- hidden_vector_dim=8,
- hidden_edge_dim=32,
- cross_edge_distance_cutoff=15.0,
- cross_edge_max_neighbors=8,
- cross_edge_num_rbf=16,
- intra_edge_distance_cutoff=15.0,
- intra_edge_max_neighbors=8,
- joint_num_layers=4,
- hidden_dim=64,
- dropout=0.0,
- use_esm_embeddings=False,
- ).to(device)
-
- protein_batch = create_dummy_protein_batch(batch_size, device)
- ligand_batch = create_dummy_ligand_batch(batch_size, device)
- t = torch.rand(batch_size, device=device)
-
- print(f"\nProteinLigandFlowMatchingJoint full forward pass:")
- print(f" Protein nodes: {protein_batch.num_nodes}")
- print(f" Ligand nodes: {ligand_batch.num_nodes}")
- print(f" Batch size: {batch_size}")
-
- velocity = model(protein_batch, ligand_batch, t)
-
- print(f" Velocity shape: {velocity.shape} (expected [{ligand_batch.num_nodes}, 3])")
- assert velocity.shape == (ligand_batch.num_nodes, 3)
-
- # Test backward pass
- loss = velocity.pow(2).mean()
- loss.backward()
- print(f" Loss: {loss.item():.6f}")
- print(f" Backward pass: OK")
-
- # Count parameters
- total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print(f" Trainable parameters: {total_params:,}")
- print(" PASSED")
-
-
-def test_model_builder():
- """Test model builder with joint architecture config."""
- from src.utils.model_builder import build_model
-
- config = {
- 'architecture': 'joint',
- 'protein_input_scalar_dim': 76,
- 'protein_input_vector_dim': 31,
- 'protein_input_edge_scalar_dim': 39,
- 'protein_input_edge_vector_dim': 8,
- 'ligand_input_scalar_dim': 122,
- 'ligand_input_edge_scalar_dim': 44,
- 'hidden_scalar_dim': 32,
- 'hidden_vector_dim': 8,
- 'hidden_edge_dim': 32,
- 'cross_edge_distance_cutoff': 15.0,
- 'cross_edge_max_neighbors': 8,
- 'cross_edge_num_rbf': 16,
- 'intra_edge_distance_cutoff': 15.0,
- 'intra_edge_max_neighbors': 8,
- 'joint_num_layers': 4,
- 'hidden_dim': 64,
- 'dropout': 0.0,
- 'use_esm_embeddings': False,
- }
-
- device = 'cpu'
- model = build_model(config, device)
-
- print(f"\nModel builder test:")
- print(f" Model type: {type(model).__name__}")
- assert type(model).__name__ == 'ProteinLigandFlowMatchingJoint'
- print(" PASSED")
-
-
-def test_esm_integration():
- """Test ESM gated concatenation integration."""
- from src.models.flowmatching import ProteinLigandFlowMatchingJoint
-
- device = 'cpu'
- batch_size = 2
-
- model = ProteinLigandFlowMatchingJoint(
- protein_input_scalar_dim=76,
- protein_input_vector_dim=31,
- protein_input_edge_scalar_dim=39,
- protein_input_edge_vector_dim=8,
- ligand_input_scalar_dim=122,
- ligand_input_edge_scalar_dim=44,
- hidden_scalar_dim=32,
- hidden_vector_dim=8,
- hidden_edge_dim=32,
- cross_edge_distance_cutoff=15.0,
- cross_edge_max_neighbors=8,
- cross_edge_num_rbf=16,
- intra_edge_distance_cutoff=15.0,
- intra_edge_max_neighbors=8,
- joint_num_layers=4,
- hidden_dim=64,
- dropout=0.0,
- use_esm_embeddings=True,
- esmc_dim=1152,
- esm3_dim=1536,
- esm_proj_dim=64, # Small for testing
- ).to(device)
-
- protein_batch = create_dummy_protein_batch(batch_size, device)
- ligand_batch = create_dummy_ligand_batch(batch_size, device)
- t = torch.rand(batch_size, device=device)
-
- # Add dummy ESM embeddings
- n_protein = protein_batch.num_nodes
- protein_batch.esmc_embeddings = torch.randn(n_protein, 1152, device=device)
- protein_batch.esm3_embeddings = torch.randn(n_protein, 1536, device=device)
-
- print(f"\nESM gated concatenation test:")
- print(f" Protein nodes: {n_protein}")
- print(f" ESM proj dim: 64")
- print(f" Effective protein scalar dim: 76 + 64 = 140")
-
- velocity = model(protein_batch, ligand_batch, t)
-
- print(f" Velocity shape: {velocity.shape} (expected [{ligand_batch.num_nodes}, 3])")
- assert velocity.shape == (ligand_batch.num_nodes, 3)
-
- # Test backward
- loss = velocity.pow(2).mean()
- loss.backward()
- print(f" Backward pass: OK")
-
- # Verify ESM gate is in computation graph (gradients assigned)
- gate_has_grad = all(p.grad is not None for p in model.esm_gate.parameters())
- print(f" ESM gate gradients: {'OK' if gate_has_grad else 'MISSING'}")
- assert gate_has_grad
-
- # Test without ESM embeddings (should zero-pad)
- model.zero_grad()
- protein_batch2 = create_dummy_protein_batch(batch_size, device)
- velocity2 = model(protein_batch2, ligand_batch, t)
- assert velocity2.shape == (ligand_batch.num_nodes, 3)
- print(f" Without ESM (zero-pad): OK")
-
- total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print(f" Trainable parameters: {total_params:,}")
- print(" PASSED")
-
-
-if __name__ == "__main__":
- print("=" * 60)
- print("Testing Joint Graph Architecture")
- print("=" * 60)
-
- test_build_cross_edges()
- test_build_intra_edges()
- test_joint_network()
- test_full_model()
- test_esm_integration()
- test_model_builder()
-
- print("\n" + "=" * 60)
- print("All tests PASSED!")
- print("=" * 60)
diff --git a/train_torsion.py b/train_torsion.py
new file mode 100644
index 0000000..42427af
--- /dev/null
+++ b/train_torsion.py
@@ -0,0 +1,465 @@
+#!/usr/bin/env python
+"""
+Training script for SE(3) + Torsion Decomposition Flow Matching.
+
+Predicts translation [3] + rotation [3] + torsion [M] instead of per-atom velocity [N, 3].
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+from tqdm import tqdm
+import yaml
+import argparse
+from datetime import datetime
+from torch.utils.data import DataLoader
+import wandb
+
+from src.data.dataset_torsion import FlowFixTorsionDataset, collate_torsion_batch
+from src.models.flowmatching_torsion import ProteinLigandFlowMatchingTorsion
+from src.utils.losses_torsion import compute_se3_torsion_loss
+from src.utils.sampling_torsion import sample_trajectory_torsion
+from src.utils.sampling import generate_timestep_schedule
+from src.utils.training_utils import build_optimizer_and_scheduler
+from src.utils.early_stop import EarlyStopping
+from src.utils.utils import set_random_seed
+from src.utils.experiment import ExperimentManager
+from src.utils.wandb_logger import (
+ WandBLogger,
+ extract_module_gradient_norms,
+ extract_parameter_stats,
+)
+
+
def build_torsion_model(model_config, device):
    """Instantiate ``ProteinLigandFlowMatchingTorsion`` from a config dict.

    Every hyperparameter falls back to the default listed below when the
    corresponding key is missing from ``model_config``. The model is moved
    to ``device`` before being returned.
    """
    # Default hyperparameters; any matching key in model_config overrides.
    defaults = {
        'protein_input_scalar_dim': 76,
        'protein_input_vector_dim': 31,
        'protein_input_edge_scalar_dim': 39,
        'protein_input_edge_vector_dim': 8,
        'protein_hidden_scalar_dim': 128,
        'protein_hidden_vector_dim': 32,
        'protein_output_scalar_dim': 128,
        'protein_output_vector_dim': 32,
        'protein_num_layers': 3,
        'ligand_input_scalar_dim': 121,
        'ligand_input_edge_scalar_dim': 44,
        'ligand_hidden_scalar_dim': 128,
        'ligand_hidden_vector_dim': 16,
        'ligand_output_scalar_dim': 128,
        'ligand_output_vector_dim': 16,
        'ligand_num_layers': 3,
        'interaction_num_heads': 8,
        'interaction_num_layers': 2,
        'interaction_num_rbf': 32,
        'interaction_pair_dim': 64,
        'velocity_hidden_scalar_dim': 128,
        'velocity_hidden_vector_dim': 16,
        'velocity_num_layers': 4,
        'hidden_dim': 256,
        'dropout': 0.1,
        'use_esm_embeddings': True,
        'esmc_dim': 1152,
        'esm3_dim': 1536,
    }
    kwargs = {key: model_config.get(key, default) for key, default in defaults.items()}
    return ProteinLigandFlowMatchingTorsion(**kwargs).to(device)
+
+
class FlowFixTorsionTrainer:
    """Trainer for SE(3) + Torsion decomposition flow matching.

    Instead of a free per-atom velocity field, the model predicts a rigid-body
    translation [3] + rotation [3] plus per-rotatable-bond torsion changes [M],
    and validation samples an ODE trajectory in that decomposed space.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['device'])

        set_random_seed(config.get('seed', 42))

        self.setup_experiment()
        self.setup_data()
        self.setup_model()
        self.setup_optimizer()
        self.setup_early_stopping()

        self.global_step = 0
        self.current_epoch = 0
        self.best_val_success = 0.0

        # setup_wandb() must run first: it defines self.wandb_enabled.
        self.setup_wandb()
        self.wandb_logger = WandBLogger(enabled=self.wandb_enabled)

    def setup_experiment(self):
        """Create the experiment directory/logger; derive a run name if absent."""
        wandb_config = self.config.get('wandb', {})
        run_name = wandb_config.get('name')
        if run_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            run_name = f"torsion_{timestamp}"

        base_dir = self.config.get('experiment', {}).get('base_dir', 'save')
        self.exp_manager = ExperimentManager(base_dir=base_dir, run_name=run_name, config=self.config)
        self.checkpoint_dir = self.exp_manager.checkpoints_dir
        self.exp_manager.logger.info(f"Experiment: {run_name} | Device: {self.device}")

    def setup_data(self):
        """Build torsion-aware train/val datasets and their DataLoaders."""
        data_config = self.config['data']
        training_config = self.config['training']

        self.train_dataset = FlowFixTorsionDataset(
            data_dir=data_config.get('data_dir', 'train_data'),
            split_file=data_config.get('split_file'),
            split='train',
            max_samples=data_config.get('max_train_samples'),
            seed=self.config.get('seed', 42),
            loading_mode=data_config.get('loading_mode', 'lazy'),
        )
        self.val_dataset = FlowFixTorsionDataset(
            data_dir=data_config.get('data_dir', 'train_data'),
            split_file=data_config.get('split_file'),
            split='valid',
            max_samples=data_config.get('max_val_samples'),
            seed=self.config.get('seed', 42),
            loading_mode=data_config.get('loading_mode', 'lazy'),
        )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=training_config['batch_size'],
            shuffle=True,
            num_workers=data_config.get('num_workers', 4),
            collate_fn=collate_torsion_batch,
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=training_config.get('val_batch_size', 4),
            shuffle=False,
            num_workers=data_config.get('num_workers', 4),
            collate_fn=collate_torsion_batch,
        )

        self.exp_manager.logger.info(
            f"Train: {len(self.train_dataset)} | Val: {len(self.val_dataset)} PDBs"
        )

    def setup_model(self):
        """Initialize torsion model."""
        self.model = build_torsion_model(self.config['model'], self.device)
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.exp_manager.logger.info(f"Model: {total_params:,} trainable params (SE(3)+Torsion)")

    def setup_optimizer(self):
        """Setup optimizer and scheduler."""
        self.optimizer, self.scheduler = build_optimizer_and_scheduler(
            self.model, self.config['training']
        )

    def setup_early_stopping(self):
        """Setup early stopping (mode='min' on a negated success rate)."""
        val_config = self.config['training'].get('validation', {})
        self.early_stopper = EarlyStopping(
            mode='min',
            patience=val_config.get('early_stopping_patience', 50),
            restore_best_weights=True,
            save_dir=str(self.checkpoint_dir),
        )

    def setup_wandb(self):
        """Initialize WandB if enabled in config; always sets self.wandb_enabled."""
        wandb_config = self.config.get('wandb', {})
        if not wandb_config.get('enabled', False):
            self.wandb_enabled = False
            return

        self.wandb_enabled = True
        wandb.init(
            project=wandb_config.get('project', 'protein-ligand-flowfix'),
            entity=wandb_config.get('entity'),
            name=self.exp_manager.run_name,
            tags=wandb_config.get('tags', []),
            dir=self.exp_manager.get_wandb_dir(),
            config={
                'model': self.config['model'],
                'training': self.config['training'],
                'output_mode': 'torsion',
            },
        )

    def train_step(self, batch):
        """Single training step for SE(3) + Torsion.

        Interpolates ligand coordinates between x0 and x1 at a random t,
        predicts the decomposed motion, and applies the SE(3)+torsion loss.
        Returns a dict of scalar metrics for logging.
        """
        ligand_batch = batch['ligand_graph'].to(self.device)
        protein_batch = batch['protein_graph'].to(self.device)
        coords_x0 = batch['ligand_coords_x0'].to(self.device)
        coords_x1 = batch['ligand_coords_x1'].to(self.device)
        batch_size = len(batch['pdb_ids'])

        # Sample one timestep per molecule.
        t = torch.rand(batch_size, device=self.device)

        # Linear interpolation x_t = (1-t) x0 + t x1, broadcast per atom.
        t_expanded = t[ligand_batch.batch].unsqueeze(-1)
        x_t = (1 - t_expanded) * coords_x0 + t_expanded * coords_x1

        ligand_batch_t = ligand_batch.clone()
        ligand_batch_t.pos = x_t

        # Torsion supervision targets (precomputed in the dataset when available).
        torsion_data = batch.get('torsion_data')
        if torsion_data is not None:
            target = {
                'translation': torsion_data['translation'].to(self.device),
                'rotation': torsion_data['rotation'].to(self.device),
                'torsion_changes': torsion_data['torsion_changes'].to(self.device),
            }
            rotatable_edges = torsion_data['rotatable_edges'].to(self.device)
            mask_rotate = torsion_data['mask_rotate'].to(self.device)
        else:
            # Fallback: rigid body only — fit a per-molecule rigid transform
            # and supervise with zero torsion changes.
            from src.data.ligand_feat import compute_rigid_transform
            translations, rotations = [], []
            for b in range(batch_size):
                mol_mask = (ligand_batch.batch == b)
                trans, rot = compute_rigid_transform(coords_x0[mol_mask].cpu(), coords_x1[mol_mask].cpu())
                translations.append(trans)
                rotations.append(rot)
            target = {
                'translation': torch.stack(translations).to(self.device),
                'rotation': torch.stack(rotations).to(self.device),
                'torsion_changes': torch.zeros(0, device=self.device),
            }
            rotatable_edges = torch.zeros(0, 2, dtype=torch.long, device=self.device)
            mask_rotate = torch.zeros(0, coords_x0.shape[0], dtype=torch.bool, device=self.device)

        # Forward
        pred = self.model(protein_batch, ligand_batch_t, t, rotatable_edges=rotatable_edges)

        # Weighted multi-term loss (translation / rotation / torsion / coord recon).
        loss_config = self.config['training'].get('torsion_loss', {})
        losses = compute_se3_torsion_loss(
            pred=pred, target=target,
            coords_x0=coords_x0, coords_x1=coords_x1,
            mask_rotate=mask_rotate, rotatable_edges=rotatable_edges,
            batch_indices=ligand_batch.batch,
            w_trans=loss_config.get('w_trans', 1.0),
            w_rot=loss_config.get('w_rot', 1.0),
            w_tor=loss_config.get('w_tor', 1.0),
            w_coord=loss_config.get('w_coord', 0.5),
        )

        loss = losses['total']

        # Backward with gradient accumulation: scale so the effective gradient
        # matches a single large batch, step only every grad_accum micro-steps.
        grad_accum = self.config['training'].get('gradient_accumulation_steps', 1)
        (loss / grad_accum).backward()

        if (self.global_step + 1) % grad_accum == 0:
            clip_val = self.config['training'].get('gradient_clip')
            if clip_val:
                nn.utils.clip_grad_norm_(self.model.parameters(), clip_val)
            self.optimizer.step()
            self.optimizer.zero_grad()

        with torch.no_grad():
            # Monitoring only: RMS deviation of the interpolated pose from x1.
            rmsd = torch.sqrt(torch.mean((x_t - coords_x1) ** 2))

        return {
            'loss': loss.item(),
            'rmsd': rmsd.item(),
            'loss_trans': losses['translation'].item(),
            'loss_rot': losses['rotation'].item(),
            'loss_tor': losses['torsion'].item(),
            'loss_coord': losses['coord_recon'].item(),
        }

    @torch.no_grad()
    def validate(self):
        """Validation with ODE sampling in torsion space.

        Returns (avg_rmsd, early_stop) and feeds the early stopper with the
        negated <2A success rate (mode='min' therefore maximizes success).
        """
        self.model.eval()

        all_rmsds = []
        all_initial_rmsds = []

        num_steps = self.config['sampling'].get('num_steps', 20)
        schedule = self.config['sampling'].get('schedule', 'uniform')
        timesteps = generate_timestep_schedule(num_steps, schedule, self.device)

        for batch in tqdm(self.val_loader, desc="Validation"):
            ligand_batch = batch['ligand_graph'].to(self.device)
            protein_batch = batch['protein_graph'].to(self.device)
            coords_x0 = batch['ligand_coords_x0'].to(self.device)
            coords_x1 = batch['ligand_coords_x1'].to(self.device)

            # Per-atom deviation before refinement (dim=-1 reduces only xyz).
            init_rmsd = torch.sqrt(torch.mean((coords_x0 - coords_x1) ** 2, dim=-1))
            all_initial_rmsds.extend(init_rmsd.cpu().numpy())

            # Torsion geometry for the sampler (empty tensors = rigid-body only).
            torsion_data = batch.get('torsion_data')
            if torsion_data is not None:
                rot_edges = torsion_data['rotatable_edges'].to(self.device)
                mask_rot = torsion_data['mask_rotate'].to(self.device)
            else:
                rot_edges = torch.zeros(0, 2, dtype=torch.long, device=self.device)
                mask_rot = torch.zeros(0, coords_x0.shape[0], dtype=torch.bool, device=self.device)

            result = sample_trajectory_torsion(
                model=self.model,
                protein_batch=protein_batch,
                ligand_batch=ligand_batch,
                x0=coords_x0,
                timesteps=timesteps,
                rotatable_edges=rot_edges,
                mask_rotate=mask_rot,
            )

            refined = result['final_coords']
            # NOTE: this is a per-ATOM deviation (matches init_rmsd above), not
            # a per-molecule RMSD — dim=-1 reduces only the coordinate axis.
            per_atom_rmsd = torch.sqrt(torch.mean((refined - coords_x1) ** 2, dim=-1))
            all_rmsds.extend(per_atom_rmsd.cpu().numpy())

        # Aggregate metrics and threshold-based success rates.
        rmsds = np.array(all_rmsds)
        avg_rmsd = rmsds.mean()
        avg_init_rmsd = np.mean(all_initial_rmsds)
        success_2a = (rmsds < 2.0).mean() * 100
        success_1a = (rmsds < 1.0).mean() * 100
        success_05a = (rmsds < 0.5).mean() * 100

        print(f"\n Validation Results:")
        print(f"   Initial RMSD: {avg_init_rmsd:.4f} A")
        print(f"   Final RMSD:   {avg_rmsd:.4f} A")
        print(f"   Success <2A: {success_2a:.1f}%  <1A: {success_1a:.1f}%  <0.5A: {success_05a:.1f}%")

        if self.wandb_enabled:
            self.wandb_logger.log_validation_epoch(
                val_loss=avg_rmsd, val_rmsd=avg_rmsd,
                val_rmsd_initial=avg_init_rmsd, val_rmsd_final=avg_rmsd,
                success_rate_2a=success_2a, success_rate_1a=success_1a,
                success_rate_05a=success_05a, epoch=self.current_epoch,
            )

        # Early stopping on negated success rate (lower is better for mode='min').
        val_metrics = {
            'rmsd': avg_rmsd, 'success_2A': success_2a,
            'success_1A': success_1a, 'success_05A': success_05a,
        }
        early_stop = self.early_stopper.step(
            score=-success_2a, model=self.model,
            optimizer=self.optimizer, scheduler=self.scheduler,
            epoch=self.current_epoch, valid_metrics=val_metrics,
        )

        if success_2a > self.best_val_success:
            self.best_val_success = success_2a

        self.model.train()
        return avg_rmsd, early_stop

    def save_checkpoint(self, filename):
        """Save a full training checkpoint (model, optimizer, scheduler, counters)."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Persist the scheduler too so a resume does not reset the LR schedule.
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'global_step': self.global_step,
            'current_epoch': self.current_epoch,
            'best_val_success': self.best_val_success,
            'config': self.config,
        }, self.checkpoint_dir / filename)

    def train(self):
        """Main training loop: epochs of train_step with periodic validation,
        checkpointing, and per-epoch scheduler stepping."""
        num_epochs = self.config['training']['num_epochs']
        val_freq = self.config['training'].get('validation', {}).get('frequency', 20)

        for epoch in range(num_epochs):
            self.current_epoch = epoch

            # Dynamic-pose datasets reshuffle per epoch when supported.
            if hasattr(self.train_dataset, 'set_epoch'):
                self.train_dataset.set_epoch(epoch)

            self.model.train()
            epoch_losses = []

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{num_epochs}")
            for batch in pbar:
                metrics = self.train_step(batch)
                epoch_losses.append(metrics['loss'])
                self.global_step += 1

                pbar.set_postfix({
                    'loss': f"{metrics['loss']:.4f}",
                    'tr': f"{metrics['loss_trans']:.3f}",
                    'rot': f"{metrics['loss_rot']:.3f}",
                    'tor': f"{metrics['loss_tor']:.3f}",
                })

                if self.wandb_enabled and self.global_step % 10 == 0:
                    self.wandb_logger.log({
                        'train/loss': metrics['loss'],
                        'train/loss_trans': metrics['loss_trans'],
                        'train/loss_rot': metrics['loss_rot'],
                        'train/loss_tor': metrics['loss_tor'],
                        'train/loss_coord': metrics['loss_coord'],
                        'train/rmsd': metrics['rmsd'],
                        'train/lr': self.optimizer.param_groups[0]['lr'],
                        'meta/epoch': epoch,
                        'meta/step': self.global_step,
                    })

            # Validation (skipped at epoch 0; expensive ODE sampling).
            early_stop = False
            if epoch > 0 and epoch % val_freq == 0:
                _, early_stop = self.validate()

            if early_stop:
                print("\n Training stopped early")
                break

            # Checkpoints (tolerate a missing 'checkpoint' config section).
            checkpoint_config = self.config.get('checkpoint', {})
            save_freq = checkpoint_config.get('save_freq', 10)
            if epoch % save_freq == 0:
                self.save_checkpoint(f'epoch_{epoch:04d}.pt')
            if checkpoint_config.get('save_latest', True):
                self.save_checkpoint('latest.pt')

            # Scheduler steps once per epoch.
            if self.scheduler:
                self.scheduler.step()

            # Epoch summary
            avg_loss = np.mean(epoch_losses)
            lr = self.optimizer.param_groups[0]['lr']
            print(f" Epoch {epoch}: loss={avg_loss:.4f} lr={lr:.6f} "
                  f"early_stop={self.early_stopper.counter}/{self.early_stopper.patience}")
+
+
def main():
    """CLI entry point: load YAML config, optionally resume, run training.

    --config  path to the YAML training config (required)
    --resume  path to a checkpoint produced by FlowFixTorsionTrainer
    --device  overrides config['device'] when given
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--device', type=str, default=None)
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    if args.device:
        config['device'] = args.device

    trainer = FlowFixTorsionTrainer(config)

    if args.resume:
        # weights_only=False: checkpoint embeds the config dict, not just tensors.
        ckpt = torch.load(args.resume, weights_only=False)
        trainer.model.load_state_dict(ckpt['model_state_dict'])
        trainer.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # Restore scheduler and bookkeeping when present; .get keeps older
        # checkpoints (without these keys) loadable.
        scheduler_state = ckpt.get('scheduler_state_dict')
        if scheduler_state is not None and trainer.scheduler is not None:
            trainer.scheduler.load_state_dict(scheduler_state)
        trainer.global_step = ckpt['global_step']
        trainer.current_epoch = ckpt.get('current_epoch', 0)
        trainer.best_val_success = ckpt.get('best_val_success', 0.0)
        print(f"Resumed from step {trainer.global_step}")

    try:
        trainer.train()
    finally:
        # Always close the WandB run, even on crash/KeyboardInterrupt.
        if trainer.wandb_enabled:
            wandb.finish()


if __name__ == '__main__':
    main()