OptiTrainer is an advanced training framework designed to enhance the performance and robustness of deep learning models, with a primary focus on image classification tasks. It integrates multiple advanced training strategies and demonstrates accuracy improvements of up to 7 percentage points on standard datasets such as CIFAR-10 compared to conventional training pipelines.
- Standard Training Pipeline (`normal_train.py`): A baseline implementation providing a straightforward training and validation workflow for performance comparison
- Optimized Training Pipeline (`train.py`): Implements two powerful ensemble strategies:
  - K-Fold Cross Validation with Diversity: Enhances model generalization through stratified data splitting and diverse data augmentation across folds
  - Snapshot Ensemble: Generates multiple model snapshots during a single training run with cosine annealing learning rate scheduling
- Advanced Training Components: Cosine annealing learning rate scheduling, automatic best model checkpointing, and soft-voting ensemble inference
- Customizable Architecture: Flexible design supporting user-defined models and datasets
```bash
# Clone OptiTrainer
git clone https://github.com/Jackksonns/OptiTrainer.git
cd OptiTrainer
```
- The CIFAR-10 dataset is used by default and will be automatically downloaded on first run
- For custom datasets, organize the data in a PyTorch-compatible format and modify the data loading section in `train.py` and `normal_train.py`
- Standard training (baseline):
  ```bash
  python normal_train.py
  ```
- Optimized training (K-Fold by default):
  ```bash
  python train.py
  ```
```python
# K-Fold implementation snippet
import numpy as np
from sklearn.model_selection import StratifiedKFold

def get_stratified_kfold_indices(labels, n_splits=5, random_state=SEED, shuffle=True):
    # SEED is the global random seed defined in train.py
    skf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
    labels = np.array(labels)
    train_idxs_list, val_idxs_list = [], []
    # Stratified splitting keeps per-class proportions consistent across folds
    for tr, vl in skf.split(np.zeros(len(labels)), labels):
        train_idxs_list.append(tr)
        val_idxs_list.append(vl)
    return train_idxs_list, val_idxs_list
```
- Stratified Splitting: Maintains class distribution across folds using `StratifiedKFold` from scikit-learn
- Diversity Enhancement: Each fold uses different random seeds and data augmentation strategies
- Per-Fold Optimization: Independent training with automatic best model selection for each fold (a usage sketch follows below)
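For orientation, here is a minimal usage sketch of the returned index lists. It is not taken from the repository; the dataset setup, batch size, and variable names are illustrative assumptions.

```python
# Usage sketch (illustrative, not from the repository): build per-fold loaders
# from the index lists returned by get_stratified_kfold_indices.
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True,
                                 transform=transforms.ToTensor())
labels = train_dataset.targets  # CIFAR-10 class labels, used for stratification

train_idxs_list, val_idxs_list = get_stratified_kfold_indices(labels, n_splits=5, random_state=42)

for fold, (tr_idx, vl_idx) in enumerate(zip(train_idxs_list, val_idxs_list)):
    train_loader = DataLoader(Subset(train_dataset, tr_idx.tolist()), batch_size=128, shuffle=True)
    val_loader = DataLoader(Subset(train_dataset, vl_idx.tolist()), batch_size=128, shuffle=False)
    print(f"Fold {fold}: {len(tr_idx)} train / {len(vl_idx)} val samples")
    # ...train a fresh model on this fold and keep its best validation checkpoint...
```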
- Single training run with periodic model checkpoints (snapshots) at specified intervals
- Leverages cosine annealing learning rate scheduling to encourage diverse model states
- Efficient alternative to traditional cross-validation with reduced computational overhead (see the sketch below)
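The snapshot loop can be pictured roughly as follows. This is a minimal sketch assuming SGD and `CosineAnnealingWarmRestarts`; the repository's actual scheduler, loop structure, and checkpoint naming may differ.

```python
# Snapshot-ensemble sketch (illustrative): one cosine cycle per snapshot interval,
# with a checkpoint saved at the end of each cycle.
import os
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

def snapshot_train(model, train_loader, criterion, num_epochs=100, init_lr=0.1, snap_interval=20):
    optimizer = torch.optim.SGD(model.parameters(), lr=init_lr, momentum=0.9)
    # Warm restarts every snap_interval epochs push the model toward diverse states
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=snap_interval)
    os.makedirs("checkpoints", exist_ok=True)
    snapshot_paths = []
    for epoch in range(1, num_epochs + 1):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
        scheduler.step()
        if epoch % snap_interval == 0:  # save one snapshot per cosine cycle
            path = f"checkpoints/snapshot_epoch{epoch}.pth"
            torch.save(model.state_dict(), path)
            snapshot_paths.append(path)
    return snapshot_paths
```

At inference time, the saved snapshots can then be loaded and combined with the same soft-voting ensemble described below.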
```python
# Ensemble prediction implementation
stacked = torch.stack(all_fold_preds, dim=0)  # [k, N, C] per-fold outputs (probs or logits, per ensemble_from)
# Weighted averaging based on validation accuracy
if use_weighted:
    w = np.array(best_acc_list, dtype=float)
    w = w / w.sum()
    w_t = torch.tensor(w, dtype=stacked.dtype).view(-1, 1, 1)
    final_probs = (stacked * w_t).sum(dim=0)
else:
    final_probs = torch.mean(stacked, dim=0)
```
- Soft-Voting: Aggregates predictions from multiple models using probability averaging
- Weighted Ensemble: Option to weight models by their validation performance
- Configurable Input: Supports averaging of probabilities or logits (see the sketch below)
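To make the probabilities-versus-logits option concrete, here is a hedged sketch of how stacked fold outputs could be normalized before voting; the function name and defaults are assumptions, not the repository's API.

```python
# Soft-voting input sketch (illustrative): when ensemble_from='logits', apply a
# softmax so that all fold outputs are averaged on a probability scale.
import torch

def combine_fold_outputs(all_fold_preds, ensemble_from="probs"):
    stacked = torch.stack(all_fold_preds, dim=0)  # [k, N, C] per-fold outputs
    if ensemble_from == "logits":
        stacked = torch.softmax(stacked, dim=-1)  # convert logits to class probabilities
    final_probs = stacked.mean(dim=0)             # unweighted soft voting over the k models
    return final_probs.argmax(dim=1)              # predicted class index per sample
```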
- `method`: Training strategy (`'kfold_diverse'` or `'snapshot'`)
- `SEED`: Random seed for reproducibility
- `num_epochs`: Number of training epochs per fold, or for snapshot training
- `init_lr`: Initial learning rate
- `n_splits`: Number of folds for cross-validation
- `use_weighted`: Whether to use validation accuracy for a weighted ensemble
- `ensemble_from`: Input type for the ensemble (`'probs'` or `'logits'`)
- `snap_interval`: Interval between snapshots (for the snapshot method; example values below)
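For illustration only, these parameters might be set near the top of `train.py` along the following lines; the values shown are examples, not the repository's shipped defaults.

```python
# Illustrative configuration values (examples only; check train.py for the actual defaults)
method = 'kfold_diverse'   # training strategy: 'kfold_diverse' or 'snapshot'
SEED = 42                  # random seed for reproducibility
num_epochs = 200           # epochs per fold, or total epochs for snapshot training
init_lr = 0.1              # initial learning rate
n_splits = 5               # number of folds for cross-validation
use_weighted = True        # weight the ensemble by per-fold validation accuracy
ensemble_from = 'probs'    # average 'probs' or 'logits'
snap_interval = 20         # epochs between snapshots (snapshot method only)
```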
- On the CIFAR-10 dataset, OptiTrainer achieves up to 7 percentage points higher test accuracy (72% vs. 65%) than the standard training pipeline with the same model architecture, initial learning rate, and number of epochs
- Example training logs (baseline vs. OptiTrainer):
  ```text
  # normal training process -- result
  step: 153000, train loss: 0.7846522927284241
  epoch 195 | test loss: 1.1244, test acc: 0.6547

  # OptiTrainer training process -- result
  Ensemble test acc: 0.7213
  Per-fold test accs: [0.6993, 0.6945, 0.6892, 0.6976, 0.695]
  Pairwise disagreement:
  [[0.     0.1981 0.2252 0.1869 0.1623]
   [0.1981 0.     0.2003 0.2069 0.2106]
   [0.2252 0.2003 0.     0.2144 0.2229]
   [0.1869 0.2069 0.2144 0.     0.1879]
   [0.1623 0.2106 0.2229 0.1879 0.    ]]
  ```
- The ensemble strategies show particular effectiveness in improving generalization on unseen data
- Pairwise disagreement analysis confirms the diversity of the models generated by both the K-Fold and Snapshot methods (a computation sketch follows below)
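The disagreement matrix in the logs above can be reproduced by comparing per-model predicted labels; below is a minimal sketch under the assumption that predictions are collected as a `[k, N]` tensor of class ids (not the repository's exact utility).

```python
# Pairwise-disagreement sketch (illustrative): fraction of test samples on which
# two models' predicted labels differ, computed for every pair of models.
import torch

def pairwise_disagreement(pred_labels):
    # pred_labels: LongTensor of shape [k, N] holding each model's predicted class ids
    k = pred_labels.shape[0]
    mat = torch.zeros(k, k)
    for i in range(k):
        for j in range(k):
            mat[i, j] = (pred_labels[i] != pred_labels[j]).float().mean()
    return mat  # symmetric, with zeros on the diagonal
```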
- `train.py`: Optimized training pipeline with K-Fold and Snapshot ensembles
- `normal_train.py`: Standard training pipeline for comparison
- `utils.py`: Utility functions for training, evaluation, and model selection
- `model.py`: Example model definition
- `data/`: Dataset directory (created automatically)
- `checkpoints/`: Directory for saved model checkpoints (created automatically)
- Author GitHub: https://github.com/Jackksonns
- Current limitations:
  - Primary focus on image classification tasks (more experiments are needed for other task types)
- Future directions:
  - Extension to object detection, segmentation, and other computer vision tasks
  - Automated hyperparameter optimization