Complete API reference for fusGAN v2.0
Main GAN class handling model training and inference.
Location: models.py
GAN(generator: tf.keras.Model = None,
discriminator: tf.keras.Model = None,
generator_optimizer: tf.keras.optimizers.Optimizer = None,
discriminator_optimizer: tf.keras.optimizers.Optimizer = None,
loss_object: tf.keras.losses.Loss = None)
Parameters:
- generator: Generator model (U-Net architecture)
- discriminator: Discriminator model (PatchGAN)
- generator_optimizer: Optimizer for generator
- discriminator_optimizer: Optimizer for discriminator
- loss_object: Loss function (typically MSE or BCE)
train_step(input_image: tf.Tensor,
target: tf.Tensor,
train_discriminator: bool = True,
train_generator: bool = True) -> Tuple
Single training step with optional selective training.
Parameters:
- input_image: Input tensor (batch_size, 128, 128, 2)
- target: Target tensor (batch_size, 128, 128, 1)
- train_discriminator: Whether to update discriminator
- train_generator: Whether to update generator
Returns: Tuple of (gen_loss, disc_loss, l1_loss, real_acc, gen_acc)
Example:
for input_image, target in dataset:
    gen_loss, disc_loss, l1, r_acc, g_acc = gan.train_step(input_image, target)
fit(train_dataset: tf.data.Dataset,
epochs: int,
disc_steps: int = 1,
gen_steps: int = 1) -> Dict[str, List]
Training loop with alternating discriminator/generator updates.
Parameters:
- train_dataset: TensorFlow dataset
- epochs: Number of training epochs
- disc_steps: Consecutive discriminator training steps per cycle
- gen_steps: Consecutive generator training steps per cycle
Returns: Dictionary with training history
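The alternation controlled by disc_steps and gen_steps can be pictured with a small scheduling sketch. It assumes each cycle runs the discriminator steps first and then the generator steps, which the reference does not state. `update_schedule` is a hypothetical helper, not part of fusGAN.

```python
from itertools import cycle, islice


def update_schedule(num_batches, disc_steps=1, gen_steps=1):
    """Return, per batch, which network would be updated ('disc' or 'gen').

    Hypothetical helper: assumes each cycle runs `disc_steps` discriminator
    updates followed by `gen_steps` generator updates.
    """
    if disc_steps < 0 or gen_steps < 0 or disc_steps + gen_steps == 0:
        raise ValueError("need at least one update per cycle")
    pattern = ["disc"] * disc_steps + ["gen"] * gen_steps
    return list(islice(cycle(pattern), num_batches))


schedule = update_schedule(num_batches=8, disc_steps=1, gen_steps=3)
print(schedule)
```

Under this assumption, each entry would correspond to one train_step call with train_discriminator or train_generator set accordingly.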
Example:
history = gan.fit(
train_dataset,
epochs=150,
disc_steps=1,
gen_steps=3
)
# history = {'generator_loss': [...], 'discriminator_loss': [...], ...}
evaluate(test_dataset: tf.data.Dataset) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, np.ndarray, float]
Evaluate model on test dataset.
Returns: (mse, psnr, ssim, lpips, eval_time)
Example:
mse, psnr, ssim, lpips, eval_time = gan.evaluate(test_dataset)
print(f"Average PSNR: {tf.reduce_mean(psnr):.2f} dB")
save_model(generator_path: str, discriminator_path: str)
load_model(generator_path: str, discriminator_path: str)
Save/load models in Keras format.
Example:
# Save
gan.save_model('models/gen.keras', 'models/disc.keras')
# Load
gan.load_model('models/gen.keras', 'models/disc.keras')
U-Net generator with skip connections.
Location: models.py
Generator(input_channels: int = 2,
output_channels: int = 1)
Parameters:
- input_channels: Number of input channels (default: 2 for CT+Mask)
- output_channels: Number of output channels (default: 1)
Attributes:
model: Keras Model instance
Architecture:
- 5 downsampling blocks (64 → 512 filters)
- Bottleneck (512 filters)
- 5 upsampling blocks with skip connections
- Output: Tanh activation
Example:
gen = Generator(2, 1)
output = gen.model(input_tensor, training=True)
Variants:
- BigGenerator: Same structure, more filters (up to 2048)
- HugeGenerator: Deeper (7 down, 5 up), up to 8192 filters
PatchGAN discriminator.
Location: models.py
Discriminator(input_channels: int = 2)
Parameters:
- input_channels: Number of input channels
Architecture:
- 70x70 PatchGAN
- 4 convolutional blocks
- Output: 30x30x1 predictions
Example:
disc = Discriminator(2)
prediction = disc.model([input_image, target], training=True)
Variants:
- SmallDiscriminator: Fewer parameters, faster training
Advanced trainer with learning rate scheduling and early stopping.
Location: src/trainer.py
GANTrainer(gan: GAN,
lr_schedule: Optional[tf.keras.optimizers.schedules.LearningRateSchedule] = None,
early_stopping: Optional[EarlyStopping] = None,
checkpoint_dir: Optional[str] = None,
save_freq: int = 10)
Parameters:
- gan: GAN instance
- lr_schedule: Learning rate schedule (optional)
- early_stopping: EarlyStopping callback (optional)
- checkpoint_dir: Directory for checkpoints (optional)
- save_freq: Save checkpoint every N epochs
Example:
trainer = GANTrainer(
gan=gan,
early_stopping=EarlyStopping(patience=15),
checkpoint_dir='checkpoints',
save_freq=10
)
fit(train_dataset: tf.data.Dataset,
epochs: int,
validation_dataset: Optional[tf.data.Dataset] = None,
disc_steps: int = 1,
gen_steps: int = 1,
callbacks: Optional[List[Callable]] = None) -> Dict[str, List]
Train with validation and callbacks.
Parameters:
- validation_dataset: Optional validation set for early stopping
- callbacks: List of callback functions with the signature callback(epoch, history)
Returns: Training history with additional fields:
- val_loss: Validation losses
- lr_gen: Generator learning rates
- lr_disc: Discriminator learning rates
Example:
def print_lr(epoch, history):
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1} - LR: {history['lr_gen'][-1]:.6f}")
history = trainer.fit(
train_dataset=train_ds,
epochs=150,
validation_dataset=val_ds,
callbacks=[print_lr]
)
Stop training when validation metric plateaus.
Location: src/trainer.py
EarlyStopping(patience: int = 10,
min_delta: float = 1e-4,
mode: str = 'min')
Parameters:
- patience: Number of epochs to wait before stopping
- min_delta: Minimum change to qualify as improvement
- mode: 'min' for loss (lower is better), 'max' for accuracy
Example:
early_stop = EarlyStopping(
patience=20,
min_delta=1e-4,
mode='min'
)
trainer = GANTrainer(gan, early_stopping=early_stop)
check(current_value: float) -> bool
Check if training should stop.
Returns: True if should stop, False otherwise
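How patience, min_delta, and mode interact is easiest to see in code. The class below is a minimal sketch of a patience-based stopper with the same constructor arguments. It is an illustration only: the name `EarlyStoppingSketch` and its internal bookkeeping are assumptions, not the fusGAN implementation.

```python
class EarlyStoppingSketch:
    """Illustrative patience-based stopper; not the fusGAN implementation."""

    def __init__(self, patience=10, min_delta=1e-4, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None  # best metric value seen so far
        self.wait = 0     # consecutive checks without improvement

    def check(self, current_value):
        """Return True once `patience` consecutive checks show no improvement."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = current_value < self.best - self.min_delta
        else:
            improved = current_value > self.best + self.min_delta
        if improved:
            self.best = current_value
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(patience=2, min_delta=0.01)
val_losses = [1.0, 0.8, 0.795, 0.796, 0.5]
decisions = [stopper.check(v) for v in val_losses]
print(decisions)
```

In this sketch a change smaller than min_delta does not count as an improvement, so the third and fourth values both increment the wait counter.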
Learning rate scheduling utilities.
Location: src/trainer.py
@staticmethod
exponential_decay(initial_lr: float,
decay_rate: float = 0.96,
decay_steps: int = 1000)
Create exponential decay schedule.
Example:
schedule = LearningRateScheduler.exponential_decay(
initial_lr=2e-4,
decay_rate=0.96,
decay_steps=1000
)
optimizer = tf.keras.optimizers.Adam(schedule, beta_1=0.5)
@staticmethod
cosine_decay(initial_lr: float,
total_steps: int,
alpha: float = 0.0)
Create cosine decay schedule (smooth decay).
Example:
schedule = LearningRateScheduler.cosine_decay(
initial_lr=2e-4,
total_steps=10000,
alpha=0.1 # Min LR = 10% of initial
)
@staticmethod
step_decay(initial_lr: float,
drop_rate: float = 0.5,
epochs_drop: int = 25)
Step-wise LR reduction.
Returns: Keras LearningRateScheduler callback
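To make the three schedules concrete, the sketch below evaluates their usual closed forms in plain Python. The exponential and cosine formulas follow the non-staircase Keras ExponentialDecay and CosineDecay definitions. Whether fusGAN wraps those classes is an assumption, and the step-decay formula is the common textbook form rather than something this reference confirms. The function names are hypothetical.

```python
import math


def exponential_decay_lr(step, initial_lr, decay_rate=0.96, decay_steps=1000):
    """lr = initial_lr * decay_rate ** (step / decay_steps)."""
    return initial_lr * decay_rate ** (step / decay_steps)


def cosine_decay_lr(step, initial_lr, total_steps, alpha=0.0):
    """Cosine decay from initial_lr down to alpha * initial_lr."""
    progress = min(step, total_steps) / total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


def step_decay_lr(epoch, initial_lr, drop_rate=0.5, epochs_drop=25):
    """Multiply the LR by drop_rate once every epochs_drop epochs (assumed form)."""
    return initial_lr * drop_rate ** (epoch // epochs_drop)


exp_lr = exponential_decay_lr(1000, 2e-4)
cos_start = cosine_decay_lr(0, 2e-4, total_steps=10000, alpha=0.1)
cos_end = cosine_decay_lr(10000, 2e-4, total_steps=10000, alpha=0.1)
step_lr = step_decay_lr(50, 2e-4)
print(exp_lr, cos_start, cos_end, step_lr)
```

With alpha=0.1 the cosine schedule ends at 10% of the initial rate, matching the comment in the cosine_decay example.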
Comprehensive evaluation and visualization.
Location: src/evaluator.py
GANEvaluator(gan: GAN)
evaluate(test_dataset: tf.data.Dataset) -> Tuple
Compute all metrics on test set.
Returns: (mse, psnr, ssim, lpips, eval_time)
Example:
evaluator = GANEvaluator(gan)
mse, psnr, ssim, lpips, eval_time = evaluator.evaluate(test_dataset)
evaluate_and_visualize(test_dataset: tf.data.Dataset,
output_dir: str,
num_samples: int = 5) -> Dict[str, np.ndarray]
Full evaluation with visualizations.
Creates:
- {output_dir}/visualizations/generated_image_{i}.png
- {output_dir}/metrics_distribution.png
Returns: Dictionary with metric arrays
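Because the returned dictionary holds per-sample arrays, a short reduction step is usually needed before reporting. The sketch below shows one way to do that. `summarize_metrics` is a hypothetical helper, and the sample values are made up for illustration.

```python
from statistics import mean, pstdev


def summarize_metrics(metrics):
    """Reduce {name: per-sample values} to {name: {'mean', 'std', 'n'}}.

    Hypothetical helper; accepts lists or NumPy arrays of scalars.
    """
    summary = {}
    for name, values in metrics.items():
        values = [float(v) for v in values]
        summary[name] = {
            "mean": mean(values),
            "std": pstdev(values),  # population standard deviation
            "n": len(values),
        }
    return summary


# Made-up per-sample values standing in for the evaluator's output.
example = {"psnr": [28.0, 30.0, 32.0], "ssim": [0.90, 0.92, 0.94]}
summary = summarize_metrics(example)
for name, stats in summary.items():
    print(f"{name}: {stats['mean']:.3f} +/- {stats['std']:.3f} (n={stats['n']})")
```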
Example:
metrics = evaluator.evaluate_and_visualize(
test_dataset,
output_dir='results',
num_samples=10
)
# metrics = {'mse': [...], 'psnr': [...], ...}
compare_models(test_dataset: tf.data.Dataset,
model_paths: Dict[str, str],
output_dir: str) -> Dict[str, Dict]
Compare multiple models side-by-side.
Example:
results = evaluator.compare_models(
test_dataset,
model_paths={
'Baseline': 'models/baseline.keras',
'Advanced': 'models/advanced.keras'
},
output_dir='comparison'
)
# Creates comparison plots automatically
Load data from MATLAB .mat files.
Location: src/dataset.py
MatDataset(filenames: list,
path: str = None,
reshape: tuple = RESHAPE,
batch_size: int = BATCH_SIZE,
shuffle: bool = SHUFFLE,
augment: bool = AUGMENT,
resize: list = RESIZE,
normalize: str = NORMALIZATION)
Parameters:
- filenames: List of .mat file paths
- batch_size: Batch size (default: 32)
- shuffle: Shuffle dataset (default: True)
- augment: Apply data augmentation (default: True)
- resize: Target size [H, W] (default: [128, 128])
- normalize: Normalization method (default: 'negpos')
Expected .mat Format:
mask: [172 x N] - Transducer mask
density: [172 x N] - CT density
PII: [172 x N] - Ultrasound simulation (target)
Example:
filenames = ['data/train/sample001.mat', 'data/train/sample002.mat']
dataset = MatDataset(filenames, batch_size=32, augment=True)
for input_image, target in dataset.dataset:
    # input_image: (32, 128, 128, 2) - CT + Mask
    # target: (32, 128, 128, 1) - Ultrasound
    pass
Load data from PNG/JPG image files.
Location: src/dataset.py
ImDataset(path: str,
filenames: list = None,
reshape: tuple = RESHAPE,
batch_size: int = BATCH_SIZE,
shuffle: bool = SHUFFLE,
augment: bool = AUGMENT,
resize: list = RESIZE,
normalize: str = NORMALIZATION)
Parameters:
- path: Root directory containing subdirectories
- filenames: Optional list of (ct, mask, sim) tuples
Expected Directory Structure:
path/
├── ct_slices/
│   ├── img001.png
│   └── img002.png
├── tr_masks/
│   ├── img001.png
│   └── img002.png
└── pi_maps/
    ├── img001.png
    └── img002.png
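Auto-discovery over this layout presumably pairs files that share a name across the three subdirectories. The sketch below shows one way to build such (ct, mask, sim) triples with pathlib. `discover_triples` is a hypothetical helper, and the match-by-filename rule is an assumption rather than documented ImDataset behavior.

```python
from pathlib import Path
import tempfile


def discover_triples(root, subdirs=("ct_slices", "tr_masks", "pi_maps")):
    """Pair files with the same name across the three subdirectories.

    Hypothetical helper: CT slices without a matching mask and simulation
    file are skipped.
    """
    root = Path(root)
    ct_dir, mask_dir, sim_dir = (root / name for name in subdirs)
    triples = []
    for ct_file in sorted(ct_dir.glob("*.png")):
        mask_file = mask_dir / ct_file.name
        sim_file = sim_dir / ct_file.name
        if mask_file.exists() and sim_file.exists():
            triples.append((str(ct_file), str(mask_file), str(sim_file)))
    return triples


# Build a throwaway directory tree that mirrors the expected structure.
with tempfile.TemporaryDirectory() as tmp:
    for sub in ("ct_slices", "tr_masks", "pi_maps"):
        (Path(tmp) / sub).mkdir()
        for name in ("img001.png", "img002.png"):
            (Path(tmp) / sub / name).touch()
    (Path(tmp) / "ct_slices" / "img003.png").touch()  # no mask or simulation
    found = discover_triples(tmp)
    found_names = [Path(ct).name for ct, _, _ in found]
print(found_names)
```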
Example:
# Auto-discover files
dataset = ImDataset(path='data/train', batch_size=16)
# Or specify files manually
dataset = ImDataset(
    path='data/train',  # path is a required argument in the signature above
    filenames=[
        ('ct.png', 'mask.png', 'sim.png'),
        ('ct2.png', 'mask2.png', 'sim2.png')
    ],
    batch_size=2
)
All metrics operate on batches and return per-sample values.
Location: src/metrics.py
MSE(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor
Mean Squared Error.
Returns: Shape (batch_size,)
Example:
mse = MSE(targets, predictions)
avg_mse = tf.reduce_mean(mse)
PSNR(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor
Peak Signal-to-Noise Ratio (in dB).
Range: Higher is better (typically 20-40 dB)
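The relationship between per-sample MSE and PSNR is PSNR = 10 * log10(max_val^2 / MSE). The sketch below works it through in plain Python on flattened samples. The function names are hypothetical, and max_val=2.0 is an assumption that corresponds to data normalized to [-1, 1]; the dynamic range used by src/metrics.py is not stated in this reference.

```python
import math


def per_sample_mse(y_true, y_pred):
    """Mean squared error per sample; each sample is a flat list of values."""
    return [
        sum((t - p) ** 2 for t, p in zip(true_s, pred_s)) / len(true_s)
        for true_s, pred_s in zip(y_true, y_pred)
    ]


def per_sample_psnr(y_true, y_pred, max_val=2.0):
    """PSNR in dB per sample; max_val is the assumed dynamic range."""
    return [
        math.inf if mse == 0 else 10.0 * math.log10(max_val ** 2 / mse)
        for mse in per_sample_mse(y_true, y_pred)
    ]


targets = [[0.0, 0.0, 0.0, 0.0], [1.0, -1.0, 1.0, -1.0]]
predictions = [[0.2, 0.2, 0.2, 0.2], [1.0, -1.0, 1.0, -1.0]]
psnr_values = per_sample_psnr(targets, predictions)
print(psnr_values)
```

A uniform error of 0.2 gives an MSE of 0.04 and therefore about 20 dB under this range, while a perfect reconstruction has unbounded PSNR.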
Example:
psnr = PSNR(targets, predictions)
print(f"Average PSNR: {tf.reduce_mean(psnr):.2f} dB")
SSIM(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor
Structural Similarity Index.
Range: typically [0, 1] (negative values are possible for strongly dissimilar images), higher is better
Example:
ssim = SSIM(targets, predictions)
print(f"Average SSIM: {tf.reduce_mean(ssim):.4f}")
LPIPS(y_true: tf.Tensor, y_pred: tf.Tensor) -> np.ndarray
Learned Perceptual Image Patch Similarity.
Note: Uses cached LPIPS model for performance.
Range: [0, 1], lower is better
Performance:
- First call: ~10s (model loading)
- Subsequent calls: ~0.1s per batch
Example:
lpips = LPIPS(targets, predictions)
print(f"Average LPIPS: {np.mean(lpips):.4f}")
Location: src/utils.py
normalize_tensor(tensor: tf.Tensor,
method: str = 'standard',
epsilon: float = 1e-8) -> tf.Tensor
Normalize tensor with division-by-zero protection.
Methods:
- 'standard': Zero mean, unit variance (z-score)
- 'minmax': Scale to [0, 1]
- 'negpos': Scale to [-1, 1]
Example:
from src.utils import normalize_tensor
normalized = normalize_tensor(image, method='negpos')
Protection: Adds epsilon to denominators to prevent NaN/Inf.
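The three methods and the epsilon guard can be sketched in plain Python as follows. `normalize_values` is a hypothetical stand-in that works on a flat list. Where exactly normalize_tensor adds epsilon is an assumption here (to the standard deviation and to the max-min range).

```python
from statistics import mean, pstdev


def normalize_values(values, method="standard", epsilon=1e-8):
    """Normalize a flat list of numbers; epsilon keeps denominators nonzero."""
    values = [float(v) for v in values]
    if method == "standard":
        mu, sigma = mean(values), pstdev(values)
        return [(v - mu) / (sigma + epsilon) for v in values]
    lo, hi = min(values), max(values)
    if method == "minmax":
        return [(v - lo) / (hi - lo + epsilon) for v in values]
    if method == "negpos":
        return [2.0 * (v - lo) / (hi - lo + epsilon) - 1.0 for v in values]
    raise ValueError(f"Unknown normalization method: {method}")


scaled = normalize_values([0.0, 5.0, 10.0], method="negpos")
flat = normalize_values([3.0, 3.0, 3.0], method="minmax")  # constant input
print(scaled, flat)
```

Without the epsilon term, the constant input would divide by zero; with it, the result is simply all zeros.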
Location: src/config.py
Configuration variables; the data and model paths can be overridden through environment variables:
# Data Paths
DATA_ROOT = os.getenv('FUSGAN_DATA_ROOT', 'data')
TRAIN_DATA_DIR = os.path.join(DATA_ROOT, 'train')
TEST_DATA_DIR = os.path.join(DATA_ROOT, 'test')
SAVED_MODELS_DIR = os.getenv('FUSGAN_MODELS_DIR', 'SavedModels')
# Training
BATCH_SIZE = 32
TRAIN_EPOCHS = 150
LEARNING_RATE = 2e-4
DISCRIMINATOR_LR = 2e-5
BETA_1 = 0.5
LAMBDA = 100
# Model
INPUT_CHANNELS = 2
OUTPUT_CHANNELS = 1
IMAGE_SIZE = (128, 128)
# Data Processing
NORMALIZATION = 'negpos'
AUGMENT = True
SHUFFLE = True
Putting it all together:
import os
os.environ['MPLBACKEND'] = 'Agg'
import numpy as np  # needed for np.mean in the final report
import tensorflow as tf
from models import GAN, Generator, SmallDiscriminator
from src.dataset import MatDataset
from src.config import *
from src.trainer import GANTrainer, EarlyStopping, LearningRateScheduler
from src.evaluator import GANEvaluator
# 1. Load data
train_files = ['data/train/sample{:03d}.mat'.format(i) for i in range(800)]
val_files = ['data/train/sample{:03d}.mat'.format(i) for i in range(800, 1000)]
test_files = ['data/test/sample{:03d}.mat'.format(i) for i in range(100)]
train_ds = MatDataset(train_files).dataset
val_ds = MatDataset(val_files, shuffle=False, augment=False).dataset
test_ds = MatDataset(test_files, batch_size=100, shuffle=False, augment=False).dataset
# 2. Create models
generator = Generator(INPUT_CHANNELS, OUTPUT_CHANNELS).model
discriminator = SmallDiscriminator(INPUT_CHANNELS).model
# 3. Setup LR schedules
steps_per_epoch = len(train_files) // BATCH_SIZE
gen_schedule = LearningRateScheduler.cosine_decay(
LEARNING_RATE, steps_per_epoch * TRAIN_EPOCHS, alpha=0.1
)
disc_schedule = LearningRateScheduler.exponential_decay(
DISCRIMINATOR_LR, decay_rate=0.96, decay_steps=steps_per_epoch * 5
)
# 4. Initialize GAN
gen_opt = tf.keras.optimizers.Adam(gen_schedule, beta_1=BETA_1)
disc_opt = tf.keras.optimizers.Adam(disc_schedule, beta_1=BETA_1)
gan = GAN(generator, discriminator, gen_opt, disc_opt, tf.keras.losses.MeanSquaredError())
# 5. Setup trainer
trainer = GANTrainer(
gan=gan,
early_stopping=EarlyStopping(patience=15),
checkpoint_dir='checkpoints',
save_freq=10
)
# 6. Train
history = trainer.fit(
train_dataset=train_ds,
epochs=TRAIN_EPOCHS,
validation_dataset=val_ds,
disc_steps=1,
gen_steps=3
)
# 7. Evaluate
evaluator = GANEvaluator(gan)
metrics = evaluator.evaluate_and_visualize(test_ds, 'results', num_samples=10)
# 8. Save
gan.save_model('models/generator_final.keras', 'models/discriminator_final.keras')
print("Training complete!")
print(f"Best validation loss: {min(history['val_loss']):.6f}")
print(f"Final PSNR: {np.mean(metrics['psnr']):.2f} dB")
All classes include proper error handling:
# Dataset validation
dataset = MatDataset([]) # ValueError: filenames cannot be None or empty
# File validation
dataset.readFile('nonexistent.mat') # FileNotFoundError with clear message
# Type validation
normalize_tensor(image, method='invalid')  # ValueError: Unknown normalization method
v2.0.0 (November 2024)
- Refactored architecture with GANTrainer and GANEvaluator
- Added learning rate scheduling
- Added early stopping
- Cached LPIPS for 10-100x speedup
- Centralized configuration
- Division-by-zero protection
- Comprehensive error handling
v1.0.0 (Initial Release)
- Basic GAN training
- Manual training loops
- Basic evaluation
For API questions or issues:
- GitHub Issues: https://github.com/aconesac/fusGAN/issues
- Email: aconesa@researchmar.net