OptiTrainer: Optimized Training Pipeline for Deep Learning

Overview

OptiTrainer is an advanced training framework designed to improve the performance and robustness of deep learning models, with a primary focus on image classification tasks. It integrates multiple advanced training strategies and demonstrates significant accuracy improvements (up to 7 percentage points) on standard datasets such as CIFAR-10 compared to conventional training approaches.

Key Features

  • Standard Training Pipeline (normal_train.py): A baseline implementation providing a straightforward training and validation workflow for performance comparison.
  • Optimized Training Pipeline (train.py): Implements two powerful ensemble strategies:
    • K-Fold Cross Validation with Diversity: Enhances model generalization through stratified data splitting and diverse data augmentation across folds
    • Snapshot Ensemble: Generates multiple model snapshots during a single training run with cosine annealing learning rate scheduling
  • Advanced Training Components: Cosine annealing learning rate scheduling, automatic best model checkpointing, and soft-voting ensemble inference
  • Customizable Architecture: Flexible design supporting user-defined models and datasets

Quick Start

1. Clone the Repository

# Clone OptiTrainer
git clone https://github.com/Jackksonns/OptiTrainer.git
cd OptiTrainer

2. Data Preparation

  • CIFAR-10 dataset is used by default and will be automatically downloaded on first run
  • For custom datasets, organize the data in a PyTorch-compatible format and modify the data loading section in train.py and normal_train.py (see the sketch below)
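
A minimal sketch of swapping in a custom dataset might look like the following. The directory layout, transform values, and loader settings here are illustrative assumptions, not part of the repository:

# Hypothetical custom-dataset swap for train.py / normal_train.py
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((32, 32)),  # match the input size the example model expects
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder expects one subdirectory per class, e.g. data/my_dataset/train/cat/
train_set = datasets.ImageFolder("data/my_dataset/train", transform=transform)
test_set = datasets.ImageFolder("data/my_dataset/test", transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)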

3. Run Training Scripts

  • Standard training (baseline):

    python normal_train.py
  • Optimized training (K-Fold by default):

    python train.py

Detailed Feature Description

Core Components of the Optimized Pipeline (train.py)

1. K-Fold Cross Validation with Diversity

# K-Fold implementation snippet (imports added for completeness)
import numpy as np
from sklearn.model_selection import StratifiedKFold

def get_stratified_kfold_indices(labels, n_splits=5, random_state=SEED, shuffle=True):
    """Return per-fold (train, val) index lists with class proportions preserved."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
    labels = np.array(labels)
    train_idxs_list, val_idxs_list = [], []
    # Only the labels matter for stratification, so a dummy feature array suffices
    for tr, vl in skf.split(np.zeros(len(labels)), labels):
        train_idxs_list.append(tr)
        val_idxs_list.append(vl)
    return train_idxs_list, val_idxs_list
  • Stratified Splitting: Maintains class distribution across folds using StratifiedKFold from scikit-learn
  • Diversity Enhancement: Each fold uses different random seeds and data augmentation strategies (illustrated in the sketch after this list)
  • Per-Fold Optimization: Independent training with automatic best model selection for each fold
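
As a sketch of the diversity idea, each fold can draw a different augmentation recipe keyed to its fold index. The specific augmentations below are illustrative, not the repo's exact pipelines:

# Sketch: per-fold augmentation diversity (illustrative pipelines)
from torchvision import transforms

def make_fold_transform(fold_idx):
    """Vary augmentation per fold so the resulting models disagree more."""
    aug = [transforms.RandomHorizontalFlip()]
    if fold_idx % 2 == 0:
        aug.append(transforms.RandomCrop(32, padding=4))
    else:
        aug.append(transforms.ColorJitter(brightness=0.2, contrast=0.2))
    aug.append(transforms.ToTensor())
    return transforms.Compose(aug)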

2. Snapshot Ensemble

  • Single training run with periodic model checkpoints (snapshots) at specified intervals
  • Leverages cosine annealing learning rate scheduling to encourage diverse model states (see the sketch after this list)
  • Efficient alternative to traditional cross-validation with reduced computational overhead
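
A minimal sketch of the snapshot loop, assuming model, criterion, train_loader, init_lr, num_epochs, and snap_interval are defined as in the parameter list below (this is an illustration of the technique, not the repo's exact code):

# Sketch: snapshot ensemble training loop (names assumed from the surrounding config)
import copy
import torch

optimizer = torch.optim.SGD(model.parameters(), lr=init_lr, momentum=0.9)
# Restart the cosine schedule every snap_interval epochs to produce diverse optima
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=snap_interval)

snapshots = []
for epoch in range(num_epochs):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    scheduler.step()
    # Take a snapshot at the end of each cosine cycle, when the LR is at its lowest
    if (epoch + 1) % snap_interval == 0:
        snapshots.append(copy.deepcopy(model.state_dict()))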

3. Ensemble Inference

# Ensemble prediction implementation
# all_fold_preds: list of per-model prediction tensors, each of shape [N, C]
# (probabilities or raw logits, depending on ensemble_from)
stacked = torch.stack(all_fold_preds, dim=0)  # [k, N, C]
if use_weighted:
    # Weight each model by its validation accuracy
    w = np.array(best_acc_list, dtype=float)
    w = w / w.sum()
    w_t = torch.tensor(w, dtype=stacked.dtype).view(-1, 1, 1)
    final_probs = (stacked * w_t).sum(dim=0)
else:
    # Plain soft voting: unweighted mean over models
    final_probs = torch.mean(stacked, dim=0)
  • Soft-Voting: Aggregates predictions from multiple models using probability averaging
  • Weighted Ensemble: Option to weight models by their validation performance
  • Configurable Input: Supports averaging of probabilities or logits (a collection sketch follows)
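
For context, all_fold_preds can be collected along these lines. This is a sketch under assumed names (snapshots, model, test_loader), not the repo's exact code:

# Sketch: gathering per-model predictions for soft voting
import torch
import torch.nn.functional as F

all_fold_preds = []
for state_dict in snapshots:  # or one best checkpoint per fold
    model.load_state_dict(state_dict)
    model.eval()
    preds = []
    with torch.no_grad():
        for x, _ in test_loader:
            logits = model(x)
            # ensemble_from='probs' averages softmax outputs; 'logits' would skip the softmax
            preds.append(F.softmax(logits, dim=1))
    all_fold_preds.append(torch.cat(preds, dim=0))  # [N, C]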

Main Parameter Explanations

Training Configuration

  • method: Training strategy ('kfold_diverse' or 'snapshot')
  • SEED: Random seed for reproducibility
  • num_epochs: Number of training epochs per fold or for snapshot training
  • init_lr: Initial learning rate
  • n_splits: Number of folds for cross-validation

Ensemble Options

  • use_weighted: Whether to use validation accuracy for weighted ensemble
  • ensemble_from: Input type for ensemble ('probs' or 'logits')
  • snap_interval: Interval between snapshots (for the snapshot method; an example configuration follows)
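
Putting these together, a configuration might look like the following (the values are illustrative, not the repo's defaults):

# Hypothetical configuration (illustrative values)
method = 'kfold_diverse'  # or 'snapshot'
SEED = 42                 # random seed for reproducibility
num_epochs = 200          # per fold, or total epochs for snapshot training
init_lr = 0.1             # initial learning rate
n_splits = 5              # number of cross-validation folds
use_weighted = True       # weight models by validation accuracy
ensemble_from = 'probs'   # average probabilities rather than logits
snap_interval = 40        # epochs between snapshots (snapshot method only)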

Experimental Results

  • On the CIFAR-10 dataset, OptiTrainer achieves up to 7 percentage points higher accuracy (72% vs. 65%) than the standard training pipeline, using the same model architecture and training budget (same initial learning rate and number of epochs)

  • Sample logs:

    # Normal training -- result
    step: 153000, train loss: 0.7846522927284241
    epoch 195 | test loss: 1.1244, test acc: 0.6547

    # OptiTrainer training -- result
    Ensemble test acc: 0.7213
    Per-fold test accs: [0.6993, 0.6945, 0.6892, 0.6976, 0.695]
    Pairwise disagreement:
    [[0.     0.1981 0.2252 0.1869 0.1623]
     [0.1981 0.     0.2003 0.2069 0.2106]
     [0.2252 0.2003 0.     0.2144 0.2229]
     [0.1869 0.2069 0.2144 0.     0.1879]
     [0.1623 0.2106 0.2229 0.1879 0.    ]]
  • The ensemble strategies show particular effectiveness in improving generalization on unseen data

  • Pairwise disagreement analysis confirms the diversity of the models generated by both the K-Fold and Snapshot methods (a computation sketch follows)
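
For reference, a disagreement matrix like the one above can be computed from hard predictions along these lines (a sketch; this helper is an assumption, not a function from utils.py):

# Sketch: pairwise disagreement between models' hard predictions
import numpy as np

def pairwise_disagreement(pred_labels):
    """pred_labels: [k, N] array of predicted class indices, one row per model."""
    k = len(pred_labels)
    mat = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            # Fraction of test samples on which models i and j predict different classes
            mat[i, j] = np.mean(pred_labels[i] != pred_labels[j])
    return mat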

Directory Structure

  • train.py: Optimized training pipeline with K-Fold and Snapshot ensemble
  • normal_train.py: Standard training pipeline for comparison
  • utils.py: Utility functions for training, evaluation, and model selection
  • model.py: Example model definition
  • data/: Dataset directory (created automatically)
  • checkpoints/: Directory for saved model checkpoints (created automatically)

Limitations & Future Work

  • Current limitations:

    • Primary focus on image classification tasks (more experiments are needed on other task types)
  • Future directions:

    • Extension to object detection, segmentation, and other computer vision tasks
    • Automated hyperparameter optimization
