-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcomparison.py
More file actions
190 lines (147 loc) · 6.61 KB
/
comparison.py
File metadata and controls
190 lines (147 loc) · 6.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
# Dataset label (shown in the chart legend and used as a summary-table column)
# -> directory holding that run's "<model_id>_metrics.csv" files.
# Insertion order sets the bar order inside each model group.
DATASETS = {
    "LDM extra 5000 images added": "Classification_Experiments/Augmented_Dataset_vdm_2/results",
    "DDPM VARIANCE extra 5000 images added": "Classification_Experiments/Augmented_Dataset_ddpm_variance_V2/results",
    "mixed_data_ldm_0.2": "Classification_Experiments/mixed_Dataset_0.2/results",
    "mixed_data_ldm_0.5": "Classification_Experiments/mixed_Dataset_0.5/results",
    # NOTE(review): "datd" looks like a typo for "data"; kept verbatim because
    # the label appears in the chart legend and the summary CSV output.
    "mixed_datd_ldm_0.8": "Classification_Experiments/mixed_Dataset_0.8/results",
}

# Model id (prefix of the metrics CSV filename) -> display name on the x axis.
MODELS = dict(
    resnet18="ResNet-18",
    convnext_tiny="Convnext-Tiny",
    swin_t="Swin-T",
    vit_tiny="ViT-Tiny",
)

# Output directory for the bar chart and summary CSV; created at import time.
SAVE_DIR = "Classification_Experiments/mixed_data"
os.makedirs(SAVE_DIR, exist_ok=True)
# ================= Data Extraction =================
def extract_max_accuracies(datasets=None, models=None, metric="test_acc"):
    """
    Collect the best (maximum) value of a metric for every model on every dataset.

    For each dataset directory, reads ``<model_id>_metrics.csv``, takes the
    maximum of the ``metric`` column, multiplies it by 100 and rounds it to
    two decimals.

    Args:
        datasets: Mapping of dataset label -> results directory.
            Defaults to the module-level ``DATASETS``.
        models: Mapping of model id (CSV filename prefix) -> display name.
            Defaults to the module-level ``MODELS``.
        metric: Column to maximise. Defaults to ``"test_acc"``; values are
            presumably fractions in [0, 1] since they are scaled by 100.

    Returns:
        Nested dict ``{dataset_label: {model_display_name: accuracy_pct}}``.
        A missing, unreadable, or empty metrics file yields ``0`` for that
        entry so downstream plotting/tabulation always has a numeric value.
    """
    if datasets is None:
        datasets = DATASETS
    if models is None:
        models = MODELS

    results = {}
    for dataset_name, dataset_path in datasets.items():
        dataset_results = {}
        for model_id, model_name in models.items():
            csv_path = os.path.join(dataset_path, f"{model_id}_metrics.csv")
            if not os.path.exists(csv_path):
                print(f"⚠️ Warning: File not found {csv_path}")
                dataset_results[model_name] = 0
                continue
            try:
                df = pd.read_csv(csv_path)
                max_acc = df[metric].max() * 100  # Convert to percentage
                # An empty or all-NaN column gives NaN, which would be plotted
                # as an invisible bar labelled "nan%"; treat it like the other
                # failure cases and record 0.
                if pd.isna(max_acc):
                    print(f"⚠️ Warning: No '{metric}' values in {csv_path}")
                    dataset_results[model_name] = 0
                    continue
                dataset_results[model_name] = round(max_acc, 2)
                print(f"✅ {dataset_name} - {model_name}: {max_acc:.2f}%")
            except Exception as e:
                # Best-effort: one bad metrics file must not abort the whole
                # comparison, so report it and fall back to 0.
                print(f"❌ Error reading {csv_path}: {e}")
                dataset_results[model_name] = 0
        results[dataset_name] = dataset_results
    return results
# ================= Bar Chart Generation =================
def create_comparison_bar_chart(results):
    """
    Draw a grouped bar chart of maximum test accuracy per model and dataset.

    The x axis has one group per model (in ``MODELS`` order). Each group has
    one bar per dataset (in ``DATASETS`` order), annotated with its value.
    The figure is saved to ``SAVE_DIR/ddpm_vdm_Bar_Chart.png``.

    Args:
        results: Nested dict ``{dataset_label: {model_display_name: accuracy_pct}}``
            as produced by ``extract_max_accuracies``. It must contain every
            label in ``DATASETS`` and every display name in ``MODELS``.

    Returns:
        Long-format DataFrame with columns ``Dataset``, ``Model``, ``Accuracy``
        (one row per entry of ``results``).
    """
    model_names = list(MODELS.values())
    dataset_names = list(DATASETS.keys())

    # Long-format table, returned for the summary step.
    data_for_df = []
    for dataset_name, model_accs in results.items():
        for model_name, acc in model_accs.items():
            data_for_df.append({
                "Dataset": dataset_name,
                "Model": model_name,
                "Accuracy": acc,
            })
    df = pd.DataFrame(data_for_df)

    plt.figure(figsize=(16, 9))

    x = np.arange(len(model_names))  # Model group positions
    # 0.15 suits up to 5 datasets; shrink for more so a group never exceeds
    # 0.8 of the spacing between models and neighbouring groups don't overlap.
    width = min(0.15, 0.8 / max(len(dataset_names), 1))
    colors = ['#3498DB', '#2ECC71', '#E74C3C', '#9B59B6', '#F39C12']  # Blue, Green, Red, Purple, Orange
    # Shift so the group of bars is centred on each model tick.
    offset = (len(dataset_names) - 1) * width / 2

    bars = []
    for i, dataset_name in enumerate(dataset_names):
        dataset_accs = [results[dataset_name][model] for model in model_names]
        bar = plt.bar(x + i * width - offset, dataset_accs, width,
                      label=dataset_name,
                      color=colors[i % len(colors)],  # cycle if > 5 datasets
                      edgecolor='black')
        bars.append(bar)

    plt.xlabel('Model', fontsize=14, fontweight='bold')
    plt.ylabel('Maximum Test Accuracy (%)', fontsize=14, fontweight='bold')
    plt.title('mixed_data Comparison: Maximum Test Accuracy', fontsize=16, fontweight='bold', pad=20)
    plt.xticks(x, model_names, fontsize=12)
    plt.legend(fontsize=12)
    # Single grid call: the style it sets is the one that was previously in
    # effect after a second grid() call overrode the first.
    plt.grid(True, axis='y', linestyle='--', alpha=0.3)
    plt.ylim(0, 110)  # Headroom above 100% for the value labels

    # Value labels on top of each bar.
    for bar_group in bars:
        for bar in bar_group:
            height = bar.get_height()
            plt.annotate(f'{height}%',
                         xy=(bar.get_x() + bar.get_width() / 2, height),
                         xytext=(0, 3),  # 3 points vertical offset
                         textcoords="offset points",
                         ha='center', va='bottom', fontsize=9, fontweight='bold')

    plt.tight_layout()

    output_path = os.path.join(SAVE_DIR, "ddpm_vdm_Bar_Chart.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"\n📊 Bar chart saved to: {output_path}")
    return df
# ================= Summary Table Generation =================
def create_summary_table(results_df, baseline="LDM extra 5000 images added", save_dir=None):
    """
    Build the Model x Dataset accuracy table and compare datasets to a baseline.

    The long-format results are pivoted so each row is a model and each column
    is a dataset. For every non-baseline dataset, a
    ``"<dataset> vs baseline (pp)"`` column holds its accuracy minus the
    baseline's, in percentage points. The table is saved as
    ``mixed_data_Comparison_Summary.csv`` and printed.

    Args:
        results_df: DataFrame with columns ``Dataset``, ``Model``, ``Accuracy``
            (accuracy in percent), with at most one row per (Model, Dataset).
        baseline: Dataset column the others are compared against. When it is
            not present, no difference columns are added.
        save_dir: Output directory for the CSV. Defaults to the module-level
            ``SAVE_DIR``.

    Returns:
        The pivoted DataFrame (index ``Model``), including any difference columns.
    """
    if save_dir is None:
        save_dir = SAVE_DIR

    pivot_df = results_df.pivot(index='Model', columns='Dataset', values='Accuracy')

    if baseline in pivot_df.columns:
        # Snapshot the dataset columns first so the newly added difference
        # columns are not themselves compared against the baseline.
        compared = [col for col in pivot_df.columns if col != baseline]
        for dataset in compared:
            # Accuracy is already in percent, so the difference is in
            # percentage points.
            pivot_df[f"{dataset} vs baseline (pp)"] = (
                pivot_df[dataset] - pivot_df[baseline]
            ).round(2)

    csv_path = os.path.join(save_dir, "mixed_data_Comparison_Summary.csv")
    pivot_df.to_csv(csv_path)
    print(f"📋 Summary table saved to: {csv_path}")

    print("\n" + "="*60)
    print("DDPM_VARIANCE_LDM Data Augmentation Comparison Summary")
    print("="*60)
    print(pivot_df.to_string())
    return pivot_df
# ================= Main Function =================
def main():
    """
    Run the comparison end to end.

    Extracts per-model maximum accuracies, draws the grouped bar chart, and
    writes/prints the summary table. All outputs go to ``SAVE_DIR``.
    """
    print("🚀 Starting five-way data augmentation comparison analysis...")
    print("="*60)

    # 1. Extract data
    print("📈 Extracting maximum accuracy for each model...")
    results = extract_max_accuracies()

    # 2. Create bar chart (also returns the long-format results table)
    print("\n🎨 Generating comparison bar chart...")
    results_df = create_comparison_bar_chart(results)

    # 3. Generate summary table (saved and printed; return value not needed here)
    print("\n📋 Generating summary table...")
    create_summary_table(results_df)

    print("\n" + "="*60)
    print("✨ Analysis completed!")
    print(f"📁 All results saved to: {SAVE_DIR}")
    print("="*60)


if __name__ == "__main__":
    main()