fusGAN API Documentation

Complete API reference for fusGAN v2.0


Table of Contents

  1. Core Classes
  2. Training Utilities
  3. Evaluation Utilities
  4. Dataset Loaders
  5. Metrics
  6. Utility Functions

Core Classes

GAN

Main GAN class handling model training and inference.

Location: models.py

Constructor

GAN(generator: tf.keras.Model = None,
    discriminator: tf.keras.Model = None,
    generator_optimizer: tf.keras.optimizers.Optimizer = None,
    discriminator_optimizer: tf.keras.optimizers.Optimizer = None,
    loss_object: tf.keras.losses.Loss = None)

Parameters:

  • generator: Generator model (U-Net architecture)
  • discriminator: Discriminator model (PatchGAN)
  • generator_optimizer: Optimizer for generator
  • discriminator_optimizer: Optimizer for discriminator
  • loss_object: Loss function (typically MSE or BCE)

Methods

train_step()
train_step(input_image: tf.Tensor,
           target: tf.Tensor,
           train_discriminator: bool = True,
           train_generator: bool = True) -> Tuple

Single training step with optional selective training.

Parameters:

  • input_image: Input tensor (batch_size, 128, 128, 2)
  • target: Target tensor (batch_size, 128, 128, 1)
  • train_discriminator: Whether to update discriminator
  • train_generator: Whether to update generator

Returns: Tuple of (gen_loss, disc_loss, l1_loss, real_acc, gen_acc)

Example:

for input_image, target in dataset:
    gen_loss, disc_loss, l1, r_acc, g_acc = gan.train_step(input_image, target)

fit()
fit(train_dataset: tf.data.Dataset,
    epochs: int,
    disc_steps: int = 1,
    gen_steps: int = 1) -> Dict[str, List]

Training loop with alternating discriminator/generator updates.

Parameters:

  • train_dataset: TensorFlow dataset
  • epochs: Number of training epochs
  • disc_steps: Consecutive discriminator training steps per cycle
  • gen_steps: Consecutive generator training steps per cycle

Returns: Dictionary with training history

Example:

history = gan.fit(
    train_dataset,
    epochs=150,
    disc_steps=1,
    gen_steps=3
)
# history = {'generator_loss': [...], 'discriminator_loss': [...], ...}
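
How `disc_steps` and `gen_steps` interleave is not spelled out above. One plausible reading, sketched here in plain Python, is that each cycle spends `disc_steps` batches updating only the discriminator and then `gen_steps` batches updating only the generator. The helper `alternating_flags` is hypothetical and not part of fusGAN; it only produces the `train_discriminator` / `train_generator` flags that `train_step()` accepts.

```python
from itertools import islice
from typing import Iterator, Tuple


def alternating_flags(disc_steps: int = 1, gen_steps: int = 1) -> Iterator[Tuple[bool, bool]]:
    """Yield (train_discriminator, train_generator) flags, one pair per batch.

    Hypothetical helper: assumes a cycle of `disc_steps` discriminator-only
    batches followed by `gen_steps` generator-only batches.
    """
    if disc_steps < 1 or gen_steps < 1:
        raise ValueError("disc_steps and gen_steps must be >= 1")
    while True:
        for _ in range(disc_steps):
            yield True, False
        for _ in range(gen_steps):
            yield False, True


# Flags for the first eight batches with disc_steps=1, gen_steps=3
schedule = list(islice(alternating_flags(1, 3), 8))
```

Under this assumption, `gen_steps=3` gives the generator three updates for every discriminator update, which is one common way to keep a strong discriminator from dominating.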

evaluate()
evaluate(test_dataset: tf.data.Dataset) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, np.ndarray, float]

Evaluate model on test dataset.

Returns: (mse, psnr, ssim, lpips, eval_time)

Example:

mse, psnr, ssim, lpips, eval_time = gan.evaluate(test_dataset)
print(f"Average PSNR: {tf.reduce_mean(psnr):.2f} dB")

save_model() / load_model()
save_model(generator_path: str, discriminator_path: str)
load_model(generator_path: str, discriminator_path: str)

Save/load models in Keras format.

Example:

# Save
gan.save_model('models/gen.keras', 'models/disc.keras')

# Load
gan.load_model('models/gen.keras', 'models/disc.keras')

Generator

U-Net generator with skip connections.

Location: models.py

Constructor

Generator(input_channels: int = 2,
          output_channels: int = 1)

Parameters:

  • input_channels: Number of input channels (default: 2 for CT+Mask)
  • output_channels: Number of output channels (default: 1)

Attributes:

  • model: Keras Model instance

Architecture:

  • 5 downsampling blocks (64 → 512 filters)
  • Bottleneck (512 filters)
  • 5 upsampling blocks with skip connections
  • Output: Tanh activation

Example:

gen = Generator(2, 1)
output = gen.model(input_tensor, training=True)

Variants:

  • BigGenerator: Same structure, more filters (up to 2048)
  • HugeGenerator: Deeper (7 down, 5 up), up to 8192 filters

Discriminator

PatchGAN discriminator.

Location: models.py

Constructor

Discriminator(input_channels: int = 2)

Parameters:

  • input_channels: Number of input channels

Architecture:

  • 70x70 PatchGAN
  • 4 convolutional blocks
  • Output: 30x30x1 predictions

Example:

disc = Discriminator(2)
prediction = disc.model([input_image, target], training=True)

Variants:

  • SmallDiscriminator: Fewer parameters, faster training

Training Utilities

GANTrainer

Advanced trainer with learning rate scheduling and early stopping.

Location: src/trainer.py

Constructor

GANTrainer(gan: GAN,
           lr_schedule: Optional[tf.keras.optimizers.schedules.LearningRateSchedule] = None,
           early_stopping: Optional[EarlyStopping] = None,
           checkpoint_dir: Optional[str] = None,
           save_freq: int = 10)

Parameters:

  • gan: GAN instance
  • lr_schedule: Learning rate schedule (optional)
  • early_stopping: EarlyStopping callback (optional)
  • checkpoint_dir: Directory for checkpoints (optional)
  • save_freq: Save checkpoint every N epochs

Example:

trainer = GANTrainer(
    gan=gan,
    early_stopping=EarlyStopping(patience=15),
    checkpoint_dir='checkpoints',
    save_freq=10
)

Methods

fit()
fit(train_dataset: tf.data.Dataset,
    epochs: int,
    validation_dataset: Optional[tf.data.Dataset] = None,
    disc_steps: int = 1,
    gen_steps: int = 1,
    callbacks: Optional[List[Callable]] = None) -> Dict[str, List]

Train with validation and callbacks.

Parameters:

  • validation_dataset: Optional validation set for early stopping
  • callbacks: List of callables, each invoked as callback(epoch, history)

Returns: Training history with additional fields:

  • val_loss: Validation losses
  • lr_gen: Generator learning rates
  • lr_disc: Discriminator learning rates

Example:

def print_lr(epoch, history):
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1} - LR: {history['lr_gen'][-1]:.6f}")

history = trainer.fit(
    train_dataset=train_ds,
    epochs=150,
    validation_dataset=val_ds,
    callbacks=[print_lr]
)
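
The example above suggests that callbacks receive a zero-based epoch index plus the history accumulated so far. The sketch below shows that calling convention in isolation; `run_callbacks`, `record_lr`, and the made-up learning-rate values are illustrative assumptions, not fusGAN code.

```python
from typing import Callable, Dict, List

History = Dict[str, List[float]]


def run_callbacks(callbacks: List[Callable[[int, History], None]],
                  epoch: int, history: History) -> None:
    """Invoke each callback as callback(epoch, history), in list order."""
    for callback in callbacks:
        callback(epoch, history)


logged = []


def record_lr(epoch: int, history: History) -> None:
    # Record the latest generator learning rate every second epoch.
    if (epoch + 1) % 2 == 0:
        logged.append((epoch + 1, history["lr_gen"][-1]))


history: History = {"lr_gen": []}
for epoch in range(4):
    history["lr_gen"].append(2e-4 * 0.5 ** epoch)  # made-up values for the demo
    run_callbacks([record_lr], epoch, history)
```

Because each callback sees the whole history dictionary, it can compute trends (for example, a moving average of `val_loss`) without the trainer having to know about them.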

EarlyStopping

Stop training when validation metric plateaus.

Location: src/trainer.py

Constructor

EarlyStopping(patience: int = 10,
              min_delta: float = 1e-4,
              mode: str = 'min')

Parameters:

  • patience: Number of consecutive epochs without improvement to wait before stopping
  • min_delta: Minimum change in the monitored value that counts as an improvement
  • mode: 'min' when lower is better (e.g. loss), 'max' when higher is better (e.g. accuracy)

Example:

early_stop = EarlyStopping(
    patience=20,
    min_delta=1e-4,
    mode='min'
)

trainer = GANTrainer(gan, early_stopping=early_stop)

Methods

check()
check(current_value: float) -> bool

Check if training should stop.

Returns: True if training should stop, False otherwise
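
The interaction of `patience`, `min_delta`, and `mode` can be sketched as follows. This is a minimal stand-in written from the parameter descriptions above, not the fusGAN implementation; the class name `EarlyStoppingSketch` and its internal attributes are assumptions.

```python
class EarlyStoppingSketch:
    """Illustrative stand-in for EarlyStopping; not the fusGAN implementation."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-4, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None  # best value seen so far
        self.wait = 0     # epochs since the last improvement

    def check(self, current_value: float) -> bool:
        """Return True once `patience` consecutive epochs bring no improvement."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = current_value < self.best - self.min_delta
        else:
            improved = current_value > self.best + self.min_delta
        if improved:
            self.best = current_value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(patience=2, min_delta=0.01)
decisions = [stopper.check(v) for v in [1.00, 0.90, 0.895, 0.899]]
```

In this run the last two values improve on 0.90 by less than `min_delta`, so they count as stalled epochs and the fourth call signals a stop.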


LearningRateScheduler

Learning rate scheduling utilities.

Location: src/trainer.py

Static Methods

exponential_decay()
@staticmethod
exponential_decay(initial_lr: float,
                 decay_rate: float = 0.96,
                 decay_steps: int = 1000)

Create exponential decay schedule.

Example:

schedule = LearningRateScheduler.exponential_decay(
    initial_lr=2e-4,
    decay_rate=0.96,
    decay_steps=1000
)
optimizer = tf.keras.optimizers.Adam(schedule, beta_1=0.5)
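
For intuition about the parameters: assuming the helper behaves like a continuous (non-staircase) exponential decay, which the text above does not confirm, the learning rate at a given optimizer step would follow `initial_lr * decay_rate ** (step / decay_steps)`. The function name below is hypothetical.

```python
def exponential_decay_lr(step: int, initial_lr: float,
                         decay_rate: float = 0.96, decay_steps: int = 1000) -> float:
    """Continuous exponential decay: initial_lr * decay_rate ** (step / decay_steps)."""
    return initial_lr * decay_rate ** (step / decay_steps)


lr_start = exponential_decay_lr(0, 2e-4)
lr_after_1000 = exponential_decay_lr(1000, 2e-4)  # one full decay period
lr_after_5000 = exponential_decay_lr(5000, 2e-4)  # five decay periods
```

With these defaults the rate drops by 4% every 1000 steps, so `decay_steps` is best chosen relative to the number of batches per epoch.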

cosine_decay()
@staticmethod
cosine_decay(initial_lr: float,
            total_steps: int,
            alpha: float = 0.0)

Create cosine decay schedule (smooth decay).

Example:

schedule = LearningRateScheduler.cosine_decay(
    initial_lr=2e-4,
    total_steps=10000,
    alpha=0.1  # Min LR = 10% of initial
)
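
The effect of `alpha` is easiest to see in the standard cosine-decay formula, sketched here in plain Python. The function name is hypothetical, and it is an assumption that the helper follows this formula (it is the one used by Keras' `CosineDecay` schedule).

```python
import math


def cosine_decay_lr(step: int, initial_lr: float, total_steps: int, alpha: float = 0.0) -> float:
    """Cosine decay from initial_lr down to alpha * initial_lr over total_steps."""
    step = min(step, total_steps)  # hold the floor value after total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


lr_first = cosine_decay_lr(0, 2e-4, 10000, alpha=0.1)
lr_middle = cosine_decay_lr(5000, 2e-4, 10000, alpha=0.1)
lr_last = cosine_decay_lr(10000, 2e-4, 10000, alpha=0.1)
```

The rate starts at `initial_lr`, passes through the midpoint between the two extremes halfway through, and settles at `alpha * initial_lr`.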

step_decay()
@staticmethod
step_decay(initial_lr: float,
          drop_rate: float = 0.5,
          epochs_drop: int = 25)

Step-wise LR reduction.

Returns: Keras LearningRateScheduler callback
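
A common form of step decay, shown here as a sketch, multiplies the learning rate by `drop_rate` once every `epochs_drop` epochs. Whether fusGAN counts epochs from zero or uses exactly this formula is an assumption; the function name is hypothetical.

```python
def step_decay_lr(epoch: int, initial_lr: float,
                  drop_rate: float = 0.5, epochs_drop: int = 25) -> float:
    """Multiply the learning rate by drop_rate once every epochs_drop epochs."""
    return initial_lr * drop_rate ** (epoch // epochs_drop)


# Learning rate at a few zero-based epoch indices
lrs = {epoch: step_decay_lr(epoch, 2e-4) for epoch in (0, 24, 25, 50)}
```

With the defaults, the rate stays at 2e-4 for epochs 0 to 24, halves at epoch 25, and halves again at epoch 50.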


Evaluation Utilities

GANEvaluator

Comprehensive evaluation and visualization.

Location: src/evaluator.py

Constructor

GANEvaluator(gan: GAN)

Methods

evaluate()
evaluate(test_dataset: tf.data.Dataset) -> Tuple

Compute all metrics on test set.

Returns: (mse, psnr, ssim, lpips, eval_time)

Example:

evaluator = GANEvaluator(gan)
mse, psnr, ssim, lpips, eval_time = evaluator.evaluate(test_dataset)

evaluate_and_visualize()
evaluate_and_visualize(test_dataset: tf.data.Dataset,
                       output_dir: str,
                       num_samples: int = 5) -> Dict[str, np.ndarray]

Full evaluation with visualizations.

Creates:

  • {output_dir}/visualizations/generated_image_{i}.png
  • {output_dir}/metrics_distribution.png

Returns: Dictionary with metric arrays

Example:

metrics = evaluator.evaluate_and_visualize(
    test_dataset,
    output_dir='results',
    num_samples=10
)
# metrics = {'mse': [...], 'psnr': [...], ...}

compare_models()
compare_models(test_dataset: tf.data.Dataset,
               model_paths: Dict[str, str],
               output_dir: str) -> Dict[str, Dict]

Compare multiple models side-by-side.

Example:

results = evaluator.compare_models(
    test_dataset,
    model_paths={
        'Baseline': 'models/baseline.keras',
        'Advanced': 'models/advanced.keras'
    },
    output_dir='comparison'
)
# Creates comparison plots automatically

Dataset Loaders

MatDataset

Load data from MATLAB .mat files.

Location: src/dataset.py

Constructor

MatDataset(filenames: list,
           path: str = None,
           reshape: tuple = RESHAPE,
           batch_size: int = BATCH_SIZE,
           shuffle: bool = SHUFFLE,
           augment: bool = AUGMENT,
           resize: list = RESIZE,
           normalize: str = NORMALIZATION)

Parameters:

  • filenames: List of .mat file paths
  • batch_size: Batch size (default: 32)
  • shuffle: Shuffle dataset (default: True)
  • augment: Apply data augmentation (default: True)
  • resize: Target size [H, W] (default: [128, 128])
  • normalize: Normalization method (default: 'negpos')

Expected .mat Format:

mask: [172 x N] - Transducer mask
density: [172 x N] - CT density
PII: [172 x N] - Ultrasound simulation (target)
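
Before building a dataset it can help to check that a loaded file matches this layout. The sketch below validates a mapping of variable names to array-like objects, such as the dictionary returned by `scipy.io.loadmat`. The helper `check_mat_contents` is hypothetical, and the requirement that all three variables share the same N is an assumption.

```python
from types import SimpleNamespace
from typing import Mapping

REQUIRED_KEYS = ("mask", "density", "PII")
EXPECTED_ROWS = 172


def check_mat_contents(contents: Mapping[str, object]) -> int:
    """Validate the three expected variables and return the shared column count N."""
    missing = [key for key in REQUIRED_KEYS if key not in contents]
    if missing:
        raise KeyError(f"missing variables: {missing}")
    shapes = {key: tuple(contents[key].shape) for key in REQUIRED_KEYS}
    for key, shape in shapes.items():
        if len(shape) != 2 or shape[0] != EXPECTED_ROWS:
            raise ValueError(f"{key} has shape {shape}, expected ({EXPECTED_ROWS}, N)")
    if len({shape[1] for shape in shapes.values()}) != 1:
        # Assumption: all three variables share the same N.
        raise ValueError(f"column counts differ: {shapes}")
    return shapes["mask"][1]


# Stand-ins for arrays; anything exposing a .shape attribute works here.
fake = {key: SimpleNamespace(shape=(172, 200)) for key in REQUIRED_KEYS}
n_columns = check_mat_contents(fake)
```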

Example:

filenames = ['data/train/sample001.mat', 'data/train/sample002.mat']
dataset = MatDataset(filenames, batch_size=32, augment=True)

for input_image, target in dataset.dataset:
    # input: (32, 128, 128, 2) - CT + Mask
    # target: (32, 128, 128, 1) - Ultrasound
    pass

ImDataset

Load data from PNG/JPG image files.

Location: src/dataset.py

Constructor

ImDataset(path: str,
          filenames: list = None,
          reshape: tuple = RESHAPE,
          batch_size: int = BATCH_SIZE,
          shuffle: bool = SHUFFLE,
          augment: bool = AUGMENT,
          resize: list = RESIZE,
          normalize: str = NORMALIZATION)

Parameters:

  • path: Root directory containing subdirectories
  • filenames: Optional list of (ct, mask, sim) tuples

Expected Directory Structure:

path/
├── ct_slices/
│   ├── img001.png
│   └── img002.png
├── tr_masks/
│   ├── img001.png
│   └── img002.png
└── pi_maps/
    ├── img001.png
    └── img002.png
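
Auto-discovery over this layout can be pictured as pairing files that share a name across the three subdirectories. The sketch below does that with `pathlib`; `discover_triplets` is a hypothetical helper, and matching by identical filename is an assumption about how `ImDataset` pairs images.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple

SUBDIRS = ("ct_slices", "tr_masks", "pi_maps")
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def discover_triplets(root: str) -> List[Tuple[str, ...]]:
    """Return (ct, mask, sim) path triplets for names present in all three folders."""
    root_path = Path(root)
    name_sets = []
    for sub in SUBDIRS:
        folder = root_path / sub
        if not folder.is_dir():
            raise FileNotFoundError(f"missing subdirectory: {folder}")
        name_sets.append({p.name for p in folder.iterdir()
                          if p.suffix.lower() in IMAGE_SUFFIXES})
    common = sorted(set.intersection(*name_sets))
    return [tuple(str(root_path / sub / name) for sub in SUBDIRS) for name in common]


# Demo on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    for sub in SUBDIRS:
        (Path(tmp) / sub).mkdir()
        for name in ("img001.png", "img002.png"):
            (Path(tmp) / sub / name).touch()
    (Path(tmp) / "ct_slices" / "orphan.png").touch()  # no mask/sim partner
    triplets = discover_triplets(tmp)
```

Files without a partner in every folder are skipped here rather than raising, which is a design choice of this sketch.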

Example:

# Auto-discover files
dataset = ImDataset(path='data/train', batch_size=16)

# Or specify files manually
dataset = ImDataset(
    path='data/train',  # path has no default in the signature above
    filenames=[
        ('ct.png', 'mask.png', 'sim.png'),
        ('ct2.png', 'mask2.png', 'sim2.png')
    ],
    batch_size=2
)

Metrics

All metrics operate on batches and return per-sample values.

Location: src/metrics.py

MSE()

MSE(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor

Mean Squared Error.

Returns: Shape (batch_size,)

Example:

mse = MSE(targets, predictions)
avg_mse = tf.reduce_mean(mse)

PSNR()

PSNR(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor

Peak Signal-to-Noise Ratio (in dB).

Range: Higher is better (typically 20-40 dB)

Example:

psnr = PSNR(targets, predictions)
print(f"Average PSNR: {tf.reduce_mean(psnr):.2f} dB")
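
The relationship between the per-sample MSE and PSNR values can be sketched in plain Python as `PSNR = 10 * log10(max_val**2 / MSE)`. Each sample is a flattened list of pixels here for brevity. The function names are hypothetical, and `max_val=2.0` assumes images normalized to [-1, 1] ('negpos'); the value fusGAN actually uses is not stated above.

```python
import math
from typing import List, Sequence


def per_sample_mse(y_true: Sequence[Sequence[float]],
                   y_pred: Sequence[Sequence[float]]) -> List[float]:
    """One MSE value per sample; each sample is a flattened list of pixels."""
    return [sum((t - p) ** 2 for t, p in zip(true, pred)) / len(true)
            for true, pred in zip(y_true, y_pred)]


def per_sample_psnr(y_true: Sequence[Sequence[float]],
                    y_pred: Sequence[Sequence[float]],
                    max_val: float = 2.0) -> List[float]:
    """PSNR in dB: 10 * log10(max_val**2 / MSE); infinite for identical images."""
    return [math.inf if mse == 0 else 10.0 * math.log10(max_val ** 2 / mse)
            for mse in per_sample_mse(y_true, y_pred)]


targets = [[-1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5]]
predictions = [[-1.0, 0.2, 1.0, -0.2], [0.5, 0.5, 0.5, 0.5]]
mse_values = per_sample_mse(targets, predictions)
psnr_values = per_sample_psnr(targets, predictions)
```

Note that a perfect reconstruction yields an infinite PSNR, so averaging over a batch that contains one needs care.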

SSIM()

SSIM(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor

Structural Similarity Index.

Range: typically [0, 1] (SSIM can be negative for strongly dissimilar images), higher is better

Example:

ssim = SSIM(targets, predictions)
print(f"Average SSIM: {tf.reduce_mean(ssim):.4f}")

LPIPS()

LPIPS(y_true: tf.Tensor, y_pred: tf.Tensor) -> np.ndarray

Learned Perceptual Image Patch Similarity.

Note: Uses cached LPIPS model for performance.

Range: non-negative, typically within [0, 1]; lower is better

Performance:

  • First call: ~10s (model loading)
  • Subsequent calls: ~0.1s per batch

Example:

lpips = LPIPS(targets, predictions)
print(f"Average LPIPS: {np.mean(lpips):.4f}")
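
The gap between the first and later calls is what one would expect from loading the model once and reusing it. How fusGAN caches the model is not described above; the sketch below shows one standard approach with `functools.lru_cache`, using a placeholder object in place of a real network. `get_perceptual_model` and `load_count` are hypothetical names.

```python
import functools

load_count = 0  # counts how often the expensive load actually runs


@functools.lru_cache(maxsize=1)
def get_perceptual_model():
    """Build the (expensive) model once; later calls return the cached object."""
    global load_count
    load_count += 1
    # Placeholder for an expensive load; a real version would build the network here.
    return object()


first = get_perceptual_model()
second = get_perceptual_model()  # served from the cache, no second load
```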

Utility Functions

normalize_tensor()

Location: src/utils.py

normalize_tensor(tensor: tf.Tensor,
                method: str = 'standard',
                epsilon: float = 1e-8) -> tf.Tensor

Normalize tensor with division-by-zero protection.

Methods:

  • 'standard': Zero mean, unit variance (z-score)
  • 'minmax': Scale to [0, 1]
  • 'negpos': Scale to [-1, 1]

Example:

from src.utils import normalize_tensor

normalized = normalize_tensor(image, method='negpos')

Protection: Adds epsilon to denominators to prevent NaN/Inf.
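
The three methods and the epsilon guard can be illustrated on a plain list of numbers. This is a sketch of the arithmetic, not the tensor implementation in `src/utils.py`; `normalize_values` is a hypothetical name, and the use of the population standard deviation for 'standard' is an assumption.

```python
from statistics import fmean, pstdev
from typing import List, Sequence


def normalize_values(values: Sequence[float], method: str = "standard",
                     epsilon: float = 1e-8) -> List[float]:
    """Plain-Python analogue of the three methods; epsilon guards the denominator."""
    if method == "standard":
        mean, std = fmean(values), pstdev(values)
        return [(v - mean) / (std + epsilon) for v in values]
    if method in ("minmax", "negpos"):
        low, high = min(values), max(values)
        unit = [(v - low) / (high - low + epsilon) for v in values]
        return unit if method == "minmax" else [2.0 * u - 1.0 for u in unit]
    raise ValueError(f"Unknown normalization method: {method}")


scaled = normalize_values([0.0, 5.0, 10.0], method="negpos")
flat = normalize_values([3.0, 3.0, 3.0], method="minmax")  # constant input, no division by zero
```

A constant image is the case the epsilon exists for: without it, `high - low` is zero and the result would be NaN.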


Configuration

Location: src/config.py

Configuration variables. The data and model paths can be overridden through environment variables:

import os

# Data Paths
DATA_ROOT = os.getenv('FUSGAN_DATA_ROOT', 'data')
TRAIN_DATA_DIR = os.path.join(DATA_ROOT, 'train')
TEST_DATA_DIR = os.path.join(DATA_ROOT, 'test')
SAVED_MODELS_DIR = os.getenv('FUSGAN_MODELS_DIR', 'SavedModels')

# Training
BATCH_SIZE = 32
TRAIN_EPOCHS = 150
LEARNING_RATE = 2e-4
DISCRIMINATOR_LR = 2e-5
BETA_1 = 0.5
LAMBDA = 100

# Model
INPUT_CHANNELS = 2
OUTPUT_CHANNELS = 1
IMAGE_SIZE = (128, 128)

# Data Processing
NORMALIZATION = 'negpos'
AUGMENT = True
SHUFFLE = True
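
Because the paths above are read with `os.getenv` at module level, the environment variables have to be set before `src.config` is imported. The sketch below mirrors the path logic as a function so the effect of an override can be inspected without importing the package; `resolve_paths` is a hypothetical helper and the `/mnt/...` paths are made-up examples.

```python
import os
from typing import Dict, Mapping


def resolve_paths(environ: Mapping[str, str] = os.environ) -> Dict[str, str]:
    """Mirror the path logic of the config listing for a given environment mapping."""
    data_root = environ.get("FUSGAN_DATA_ROOT", "data")
    return {
        "DATA_ROOT": data_root,
        "TRAIN_DATA_DIR": os.path.join(data_root, "train"),
        "TEST_DATA_DIR": os.path.join(data_root, "test"),
        "SAVED_MODELS_DIR": environ.get("FUSGAN_MODELS_DIR", "SavedModels"),
    }


defaults = resolve_paths({})
overridden = resolve_paths({
    "FUSGAN_DATA_ROOT": "/mnt/fus",          # made-up example path
    "FUSGAN_MODELS_DIR": "/mnt/models",      # made-up example path
})
```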

Complete Training Example

Putting it all together:

import os
os.environ['MPLBACKEND'] = 'Agg'

import numpy as np
import tensorflow as tf
from models import GAN, Generator, SmallDiscriminator
from src.dataset import MatDataset
from src.config import *
from src.trainer import GANTrainer, EarlyStopping, LearningRateScheduler
from src.evaluator import GANEvaluator

# 1. Load data
train_files = ['data/train/sample{:03d}.mat'.format(i) for i in range(800)]
val_files = ['data/train/sample{:03d}.mat'.format(i) for i in range(800, 1000)]
test_files = ['data/test/sample{:03d}.mat'.format(i) for i in range(100)]

train_ds = MatDataset(train_files).dataset
val_ds = MatDataset(val_files, shuffle=False, augment=False).dataset
test_ds = MatDataset(test_files, batch_size=100, shuffle=False, augment=False).dataset

# 2. Create models
generator = Generator(INPUT_CHANNELS, OUTPUT_CHANNELS).model
discriminator = SmallDiscriminator(INPUT_CHANNELS).model

# 3. Setup LR schedules
steps_per_epoch = len(train_files) // BATCH_SIZE
gen_schedule = LearningRateScheduler.cosine_decay(
    LEARNING_RATE, steps_per_epoch * TRAIN_EPOCHS, alpha=0.1
)
disc_schedule = LearningRateScheduler.exponential_decay(
    DISCRIMINATOR_LR, decay_rate=0.96, decay_steps=steps_per_epoch * 5
)

# 4. Initialize GAN
gen_opt = tf.keras.optimizers.Adam(gen_schedule, beta_1=BETA_1)
disc_opt = tf.keras.optimizers.Adam(disc_schedule, beta_1=BETA_1)
gan = GAN(generator, discriminator, gen_opt, disc_opt, tf.keras.losses.MeanSquaredError())

# 5. Setup trainer
trainer = GANTrainer(
    gan=gan,
    early_stopping=EarlyStopping(patience=15),
    checkpoint_dir='checkpoints',
    save_freq=10
)

# 6. Train
history = trainer.fit(
    train_dataset=train_ds,
    epochs=TRAIN_EPOCHS,
    validation_dataset=val_ds,
    disc_steps=1,
    gen_steps=3
)

# 7. Evaluate
evaluator = GANEvaluator(gan)
metrics = evaluator.evaluate_and_visualize(test_ds, 'results', num_samples=10)

# 8. Save
gan.save_model('models/generator_final.keras', 'models/discriminator_final.keras')

print("Training complete!")
print(f"Best validation loss: {min(history['val_loss']):.6f}")
print(f"Final PSNR: {np.mean(metrics['psnr']):.2f} dB")

Error Handling

Classes and functions validate their inputs and raise descriptive errors:

# Dataset validation
dataset = MatDataset([])  # ValueError: filenames cannot be None or empty

# File validation
dataset.readFile('nonexistent.mat')  # FileNotFoundError with clear message

# Type validation
normalize_tensor(image, method='invalid')  # ValueError: Unknown normalization method

Version History

v2.0.0 (November 2024)

  • Refactored architecture with GANTrainer and GANEvaluator
  • Added learning rate scheduling
  • Added early stopping
  • Cached LPIPS for 10-100x speedup
  • Centralized configuration
  • Division-by-zero protection
  • Comprehensive error handling

v1.0.0 (Initial Release)

  • Basic GAN training
  • Manual training loops
  • Basic evaluation

Support

For API questions or issues: